본문으로 건너뛰기

[vllm] ROCm AITER MHA 백엔드 재설계

PR 링크: vllm-project/vllm#25763 상태: Merged | 변경: +593/-275

들어가며

AMD ROCm 플랫폼에서 AITER(AI Tensor Engine for ROCm) MHA 백엔드가 완전히 재설계되었다. 기존 구현은 KV cache 레이아웃 변환에 비효율이 있었고, 컨텍스트 병렬 처리 지원이 부족했다. 새 설계는 Triton 커널 기반의 효율적인 cache gather와 SM 인식 스케줄링을 도입한다.

핵심 코드 분석

Triton 기반 Cache Gather 커널

# Before: 단순 layout 변환
def _vllm_layout_trans_kernel(k_buffer_ptr, v_buffer_ptr, ...)

# After: Context Parallel 지원 cache gather
@triton.jit
def cp_mha_gather_cache_kernel(
    key_cache_ptr,    # [num_blocks, page_size, num_head, head_size]
    value_cache_ptr,  # [num_blocks, page_size, num_head, head_size]
    key_ptr,          # [num_tokens, num_heads, head_size]
    value_ptr,        # [num_tokens, num_heads, head_size]
    block_table_ptr,  # [num_batches, max_block_num]
    cu_seqlens_kv_ptr,
    token_to_batch_ptr,
    ...
    DEQUANT: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    CACHE_FORMAT: tl.constexpr,
):

SM 인식 최적화

def block_size(x, head_dim):
    return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))

def num_programs(head_dim):
    return min(head_dim, get_num_sms())

GPU의 SM(Streaming Multiprocessor) 수를 고려하여 프로그램 수를 결정한다. 이는 GPU 리소스를 과도하게 할당하거나 부족하게 할당하는 것을 방지한다.

왜 이게 좋은가

  1. Context Parallel 지원: 긴 컨텍스트를 여러 GPU에 분산 처리 가능
  2. 양자화 호환: DEQUANT 플래그로 FP8 등 양자화된 KV cache에서도 동작
  3. SM 최적화: GPU 리소스를 효율적으로 활용하는 프로그램 스케줄링
  4. 코드 정리: 275줄 삭제로 레거시 코드를 정리하고 더 깔끔한 구현으로 대체

정리

AMD GPU 사용자에게 중요한 최적화다. AITER MHA 백엔드의 재설계로 ROCm 플랫폼에서의 어텐션 성능이 개선되고, context parallel 등 최신 기능과의 호환성이 확보되었다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글

관련 포스트

PR Analysis 의 다른글