[vllm] ROCm AITER MHA 백엔드 재설계
PR 링크: vllm-project/vllm#25763 상태: Merged | 변경: +593/-275
들어가며
AMD ROCm 플랫폼에서 AITER(AI Tensor Engine for ROCm) MHA 백엔드가 완전히 재설계되었다. 기존 구현은 KV cache 레이아웃 변환에 비효율이 있었고, 컨텍스트 병렬 처리 지원이 부족했다. 새 설계는 Triton 커널 기반의 효율적인 cache gather와 SM 인식 스케줄링을 도입한다.
핵심 코드 분석
Triton 기반 Cache Gather 커널
# Before: 단순 layout 변환
def _vllm_layout_trans_kernel(k_buffer_ptr, v_buffer_ptr, ...)
# After: Context Parallel 지원 cache gather
@triton.jit
def cp_mha_gather_cache_kernel(
key_cache_ptr, # [num_blocks, page_size, num_head, head_size]
value_cache_ptr, # [num_blocks, page_size, num_head, head_size]
key_ptr, # [num_tokens, num_heads, head_size]
value_ptr, # [num_tokens, num_heads, head_size]
block_table_ptr, # [num_batches, max_block_num]
cu_seqlens_kv_ptr,
token_to_batch_ptr,
...
DEQUANT: tl.constexpr,
PAGE_SIZE: tl.constexpr,
CACHE_FORMAT: tl.constexpr,
):
SM 인식 최적화
def block_size(x, head_dim):
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
def num_programs(head_dim):
return min(head_dim, get_num_sms())
GPU의 SM(Streaming Multiprocessor) 수를 고려하여 프로그램 수를 결정한다. 이는 GPU 리소스를 과도하게 할당하거나 부족하게 할당하는 것을 방지한다.
왜 이게 좋은가
- Context Parallel 지원: 긴 컨텍스트를 여러 GPU에 분산 처리 가능
- 양자화 호환:
DEQUANT플래그로 FP8 등 양자화된 KV cache에서도 동작 - SM 최적화: GPU 리소스를 효율적으로 활용하는 프로그램 스케줄링
- 코드 정리: 275줄 삭제로 레거시 코드를 정리하고 더 깔끔한 구현으로 대체
정리
AMD GPU 사용자에게 중요한 최적화다. AITER MHA 백엔드의 재설계로 ROCm 플랫폼에서의 어텐션 성능이 개선되고, context parallel 등 최신 기능과의 호환성이 확보되었다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] Tutorials: 벤치마크 결과 테이블에 단위(units) 표시 추가
- 현재글 : [vllm] ROCm AITER MHA 백엔드 재설계
- 다음글 [pydantic-ai] Validation 에러 재시도 메시지 개선 — Markdown 코드 블록 포맷
댓글