본문으로 건너뛰기

[flashinfer] FlashInfer MLA 커널 최적화: num_heads < 128 환경에서의 성능 극대화

PR 링크: flashinfer-ai/flashinfer#3309 상태: Merged | 변경: +1434 / -275

들어가며

최신 LLM 추론에서 Multi-Head Latent Attention (MLA)은 메모리 효율성을 극대화하는 핵심 기술입니다. 하지만 Blackwell 아키텍처와 같은 고성능 GPU에서 num_heads가 128 미만인 경우, GPU의 MMA(Matrix Multiply-Accumulate) 유닛 활용도가 떨어져 성능 저하가 발생합니다. 본 PR은 seqlen_q를 헤드 차원으로 폴딩(folding)하여 M-tile을 완전히 채움으로써, 특히 FP8 MLA 디코드 성능을 획기적으로 개선합니다.

코드 분석

1. flashinfer/cute_dsl/attention/monolithic/mla_decode.py: 폴딩 로직 도입

핵심 아이디어는 num_heads가 작을 때 seqlen_q의 일부를 헤드 차원으로 재배치하여 MMA 연산의 효율을 높이는 것입니다.

# Before/After: 폴딩 비율 계산 및 워크스페이스 크기 조정
fold_sq_ratio = BlackwellMultiHeadLatentAttentionForwardFP16.compute_fold_sq_ratio(
    H, q_len, mma_qk_tile_m
)
num_heads_eff = H * fold_sq_ratio
seq_len_q_eff = q_len // fold_sq_ratio

fold_sq_ratio를 계산하여 num_heads가 128을 채우지 못할 때 seqlen_q를 나누어 헤드 수를 늘립니다. 이는 MLA가 모든 헤드/쿼리 간에 KV를 독립적으로 공유한다는 특성을 활용한 것입니다.

2. flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py: 커널 설정 변경

커널 초기화 시 fold_sq 플래그를 통해 폴딩 경로를 활성화합니다.

# 폴딩 로직을 통한 텐서 레이아웃 재구성
if cutlass.const_expr(self.fold_sq):
    F = self.fold_sq_ratio
    def _fold_sq_4d(t):
        return cute.make_tensor(
            t.iterator,
            cute.make_layout(
                # ... 레이아웃 재구성 로직 ...
            )
        )

왜 이게 좋은가

이번 최적화는 Blackwell B200 환경에서 num_heads=16일 때 최대 3.6배 이상의 성능 향상을 보여줍니다.

  • 효율성: MMA 유닛의 M-tile(128)을 낭비 없이 활용하여 연산 밀도를 높였습니다.
  • 유연성: lse(Log-Sum-Exp) 버퍼 지원을 추가하여, 기존 trtllm-gen 스타일의 2D 레이아웃과 네이티브 3D 레이아웃을 모두 지원하도록 API를 개선했습니다.

일반적 교훈

  1. 하드웨어 아키텍처 이해: GPU의 MMA 타일 크기에 맞춰 데이터를 재배치(Folding)하는 것은 저차원 헤드 구조에서 필수적인 최적화입니다.
  2. API 유연성: 라이브러리 설계 시, 기존 생태계(trtllm 등)와의 호환성을 위해 다양한 입력 텐서 레이아웃을 수용하는 것이 중요합니다.

리뷰어 피드백

리뷰 과정에서 MTP(Multi-Token Prediction) 환경에서의 정합성 이슈가 논의되었으나, 이는 커널 자체의 문제가 아닌 통합 과정의 워크스페이스 공유 문제로 밝혀졌습니다. 또한, lse 버퍼의 2D/3D 레이아웃 지원을 통해 API 사용성을 크게 높였습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글