[flashinfer] FlashInfer MLA 커널 최적화: num_heads < 128 환경에서의 성능 극대화
PR 링크: flashinfer-ai/flashinfer#3309 상태: Merged | 변경: +1434 / -275
들어가며
최신 LLM 추론에서 Multi-Head Latent Attention (MLA)은 메모리 효율성을 극대화하는 핵심 기술입니다. 하지만 Blackwell 아키텍처와 같은 고성능 GPU에서 num_heads가 128 미만인 경우, GPU의 MMA(Matrix Multiply-Accumulate) 유닛 활용도가 떨어져 성능 저하가 발생합니다. 본 PR은 seqlen_q를 헤드 차원으로 폴딩(folding)하여 M-tile을 완전히 채움으로써, 특히 FP8 MLA 디코드 성능을 획기적으로 개선합니다.
코드 분석
1. flashinfer/cute_dsl/attention/monolithic/mla_decode.py: 폴딩 로직 도입
핵심 아이디어는 num_heads가 작을 때 seqlen_q의 일부를 헤드 차원으로 재배치하여 MMA 연산의 효율을 높이는 것입니다.
# Before/After: 폴딩 비율 계산 및 워크스페이스 크기 조정
fold_sq_ratio = BlackwellMultiHeadLatentAttentionForwardFP16.compute_fold_sq_ratio(
H, q_len, mma_qk_tile_m
)
num_heads_eff = H * fold_sq_ratio
seq_len_q_eff = q_len // fold_sq_ratio
fold_sq_ratio를 계산하여 num_heads가 128을 채우지 못할 때 seqlen_q를 나누어 헤드 수를 늘립니다. 이는 MLA가 모든 헤드/쿼리 간에 KV를 독립적으로 공유한다는 특성을 활용한 것입니다.
2. flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py: 커널 설정 변경
커널 초기화 시 fold_sq 플래그를 통해 폴딩 경로를 활성화합니다.
# 폴딩 로직을 통한 텐서 레이아웃 재구성
if cutlass.const_expr(self.fold_sq):
F = self.fold_sq_ratio
def _fold_sq_4d(t):
return cute.make_tensor(
t.iterator,
cute.make_layout(
# ... 레이아웃 재구성 로직 ...
)
)
왜 이게 좋은가
이번 최적화는 Blackwell B200 환경에서 num_heads=16일 때 최대 3.6배 이상의 성능 향상을 보여줍니다.
- 효율성: MMA 유닛의 M-tile(128)을 낭비 없이 활용하여 연산 밀도를 높였습니다.
- 유연성:
lse(Log-Sum-Exp) 버퍼 지원을 추가하여, 기존trtllm-gen스타일의 2D 레이아웃과 네이티브 3D 레이아웃을 모두 지원하도록 API를 개선했습니다.
일반적 교훈
- 하드웨어 아키텍처 이해: GPU의 MMA 타일 크기에 맞춰 데이터를 재배치(Folding)하는 것은 저차원 헤드 구조에서 필수적인 최적화입니다.
- API 유연성: 라이브러리 설계 시, 기존 생태계(trtllm 등)와의 호환성을 위해 다양한 입력 텐서 레이아웃을 수용하는 것이 중요합니다.
리뷰어 피드백
리뷰 과정에서 MTP(Multi-Token Prediction) 환경에서의 정합성 이슈가 논의되었으나, 이는 커널 자체의 문제가 아닌 통합 과정의 워크스페이스 공유 문제로 밝혀졌습니다. 또한, lse 버퍼의 2D/3D 레이아웃 지원을 통해 API 사용성을 크게 높였습니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
- https://docs.nvidia.com/cuda/cutlass/index.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [flashinfer] FlashInfer FP8 KV-Cache Prefill 성능 최적화: Repacking 기법을 통한 오버헤드 제거
- [triton] [Triton] Persistent Matmul 성능을 13% 향상시킨 정교한 Shared Memory 계산 기법 분석
- [vllm] vLLM 기술 딥다이브: CUTLASS를 활용한 NVFP4 Linear 커널의 Batch Invariance 최적화
- [flashinfer] FlashInfer의 DeepSeek V4 Sparse MLA 최적화 분석
- [vllm] Blackwell을 위한 새로운 MLA 백엔드: TOKENSPEED_MLA 분석 (DeepSeek R1 최적화)
PR Analysis 의 다른글
- 이전글 [axolotl] Axolotl MoE 모델 최적화: Tiled-MLP 도입 및 FSDP2 통합으로 성능 극대화
- 현재글 : [flashinfer] FlashInfer MLA 커널 최적화: num_heads < 128 환경에서의 성능 극대화
- 다음글 [onnxruntime] ONNX Runtime의 CPU GQA 최적화: Flash Attention과 Flash Decoding 도입
댓글