[flashinfer] FlashInfer, SM120 GPU를 위한 희소 MLA 커널 추가로 LLM 추론 속도 향상
PR 링크: flashinfer-ai/flashinfer#3395 상태: Merged | 변경: +10727 / -99
들어가며
최근 대규모 언어 모델(LLM)의 발전 속도가 빨라지면서, 모델의 추론 성능을 극대화하는 것은 매우 중요한 과제가 되었습니다. 특히 GPU 아키텍처의 특성을 최대한 활용하는 것은 성능 향상의 핵심입니다. NVIDIA의 최신 GPU 아키텍처인 SM120 (Ampere 아키텍처의 일부)는 이전 세대와 다른 메모리 계층 구조와 연산 능력을 가지고 있습니다. 이러한 변화에 발맞춰, FlashInfer 라이브러리는 SM120 GPU에 최적화된 새로운 희소 Multi-Head Attention (MLA) 커널을 도입하여 LLM 추론 성능을 한 단계 끌어올렸습니다.
이번 PR은 SM120 및 SM121 GPU에서 사용되는 DSv4 (d_qk=512) 및 DSv3.2 / GLM-5.1 (d_qk=576) 모델을 위한 희소 MLA 커널을 추가합니다. 이 개선 사항은 특히 paged attention 시나리오에서 디코딩(decode) 및 프리필(prefill) 단계 모두에서 성능 향상을 가져옵니다. 사용자는 기존의 flashinfer.mla API를 그대로 사용하면서, 백엔드에서 자동으로 최적화된 SM120 커널을 활용할 수 있게 됩니다.
코드 분석
이번 PR의 핵심은 SM120 아키텍처의 특성을 활용하는 새로운 CUDA 커널을 개발하고, 이를 FlashInfer의 기존 파이프라인에 통합하는 것입니다. 주요 변경 사항은 다음과 같습니다.
1. benchmarks/bench_sparse_mla_sm120.py (새 파일)
이 파일은 새로 추가된 SM120 희소 MLA 커널의 성능을 측정하기 위한 마이크로벤치마크를 포함합니다. DSv4 및 DSv3.2 모델의 다양한 시나리오 (단일 캐시, 듀얼 캐시, 디코딩, 프리필)에 대한 성능 지표 (지연 시간, KV 대역폭, TFLOPs)를 측정하고 보고합니다.
- DSv4 프리필 (단일 캐시): 다양한 배치 크기 (NH), top-k 값, 시퀀스 길이 (T)에 대해 지연 시간, KV GB/s, TFLOPs를 측정합니다.
- DSv4 프리필 (듀얼 캐시): 메인 캐시와 보조 캐시를 함께 사용하는 시나리오를 측정합니다. 이는 더 복잡한 메모리 접근 패턴을 가지며, 성능 향상의 잠재력이 큽니다.
- DSv4 디코딩: T=1인 디코딩 단계에서의 지연 시간을 측정합니다.
- DSv3.2 / GLM-5.1: 다른 모델 아키텍처에 대한 성능을 측정합니다.
이 벤치마크 파일은 SM120 커널의 실제 성능을 정량적으로 보여주는 중요한 역할을 합니다.
2. flashinfer/mla/_sparse_mla_sm120.py (새 파일)
이 파일은 SM120 GPU를 위한 희소 MLA 커널의 핵심 로직을 구현합니다. _SparseMLAPagedAttentionRunner 클래스는 LSE (LogSumExp) 및 디코딩 스크래치 버퍼를 관리하는 내부 헬퍼 역할을 합니다. 또한, sparse_mla_sm120_decode_dsv4와 같은 함수들은 JIT 모듈 및 AutoTuner의 전술 캐시를 활용하여 커널을 직접 래핑합니다.
이 파일의 가장 중요한 부분은 flashinfer.jit.mla.gen_sparse_mla_sm120_module() 함수로, 새로운 SM120 희소 MLA CUDA 소스를 JIT/AOT 빌드 시스템에 통합하는 역할을 합니다. 이를 통해 Python API 호출이 실제 최적화된 CUDA 커널로 연결됩니다.
3. flashinfer/mla/_core.py (수정)
이 파일은 기존 MLA API의 동작을 SM 아키텍처에 따라 분기하는 로직을 포함합니다. 이번 PR에서는 SM120/SM121 지원을 추가하기 위해 기존 API의 동작 방식을 수정하고, 문서화(docstring)를 개선했습니다.
- 기존 API 확장:
trtllm_batch_decode_sparse_mla_dsv4와 같은 기존 API 엔드포인트가 SM120/SM121에서도 동작하도록 확장되었습니다. SM100/103은 기존 TRTLLM-GEN 경로를 사용하고, SM120/121은 새로운 희소 백엔드를 사용합니다. - 문서화 개선: SM100/103과 SM120/121에서의 API 동작 방식 차이, KV 캐시 레이아웃, 지원되는 데이터 타입 등에 대한 문서화가 명확하게 개선되었습니다. 이는 사용자가 API를 올바르게 이해하고 활용하는 데 도움을 줍니다.
리뷰어 bkryu는 SM 아키텍처에 따라 고유한 커널이 사용된다면 backend 파라미터가 불필요할 수 있다는 점을 지적했으며, API 문서화를 사용자 관점에서 더 명확하게 작성할 것을 제안했습니다. 이러한 피드백은 API의 사용성을 높이는 데 기여했습니다.
4. flashinfer/jit/sparse_mla_sm120.py (새 파일)
이 파일은 SM120 희소 MLA 커널을 위한 JIT (Just-In-Time) 컴파일 모듈을 정의합니다. gen_sparse_mla_sm120_module 함수는 CUDA 소스 코드를 컴파일하고, 이를 Python에서 사용할 수 있는 객체로 만듭니다. 이는 동적으로 커널을 생성하고 최적화하는 FlashInfer의 JIT 컴파일 전략의 일부입니다.
리뷰어 bkryu는 이 파일의 내용을 flashinfer/jit/mla.py로 통합하는 것을 제안했으나, SM120 특화 커널의 독립성을 유지하기 위해 별도 파일로 관리하는 현재 방식도 합리적인 선택으로 보입니다.
5. flashinfer/mla/trtllm.py (수정)
trtllm_batch_decode_with_kv_cache_mla 함수는 기존 MLA 디코딩 API를 확장하여 SM120/SM121을 위한 희소 백엔드를 지원합니다. backend='auto' 또는 backend='sparse' 옵션을 통해 SM120/SM121에서 새로운 희소 백엔드를 선택할 수 있습니다. 또한, kv_scale_format 파라미터를 통해 Packed KV 캐시의 스케일 형식을 제어할 수 있습니다.
이 수정은 사용자가 기존 API를 변경 없이 사용하면서도 최신 하드웨어의 성능을 활용할 수 있도록 하는 데 중점을 둡니다.
왜 이게 좋은가?
이번 PR은 여러 측면에서 중요한 개선을 이루었습니다:
- SM120/SM121 GPU 최적화: 최신 NVIDIA GPU 아키텍처인 SM120/SM121의 특성을 활용하는 전용 CUDA 커널을 도입함으로써, 이전 세대 GPU 대비 상당한 성능 향상을 달성했습니다. 특히 DSv4 및 DSv3.2 모델에서 지연 시간 감소와 처리량 증가를 확인할 수 있습니다.
- 희소 MLA 커널 도입: 희소 어텐션 메커니즘은 LLM에서 발생하는 계산량의 상당 부분을 차지하는 어텐션 계산을 효율적으로 만듭니다. 특히 KV 캐시의 일부 토큰만 고려하는 top-k 샘플링과 같은 기법과 결합될 때, 계산량을 크게 줄일 수 있습니다. SM120용 희소 MLA 커널은 이러한 희소성을 더욱 효과적으로 활용합니다.
- Paged Attention 성능 향상: Paged Attention은 동적으로 할당되는 KV 캐시를 효율적으로 관리하는 기법입니다. 이번 PR에서 추가된 커널들은 SM120 아키텍처에서 paged attention의 성능을 최적화하여, 특히 긴 시퀀스 처리 시 메모리 접근 효율성을 높였습니다.
- 성능 향상 수치: PR 설명에 따르면, RTX PRO 6000 Blackwell Server Edition (SM120a)에서 DSv4 듀얼 캐시 프리필 시나리오에서 기존 대비 1.19-1.32배의 성능 향상이 관찰되었습니다. 이는 실제 서비스 환경에서 LLM 추론 속도를 크게 단축시킬 수 있는 중요한 결과입니다.
- API 호환성 유지: 사용자는 기존의
flashinfer.mlaAPI를 그대로 사용할 수 있습니다. SM 아키텍처에 따른 최적화는 내부적으로 처리되므로, API 변경 없이도 성능 향상을 누릴 수 있습니다. 이는 라이브러리의 사용성을 크게 높이는 요소입니다.
일반적 교훈:
- 하드웨어 특성 활용: 새로운 하드웨어 아키텍처가 출시될 때마다, 해당 아키텍처의 고유한 기능과 메모리 계층 구조를 이해하고 이를 활용하는 최적화된 커널을 개발하는 것이 성능 향상의 핵심입니다.
- JIT 컴파일 및 AOT 빌드: FlashInfer와 같이 JIT 컴파일 또는 AOT 빌드 파이프라인을 활용하면, 다양한 하드웨어 및 사용 사례에 맞춰 동적으로 최적화된 코드를 생성하여 성능을 극대화할 수 있습니다.
- 명확한 API 설계 및 문서화: 복잡한 내부 구현을 추상화하여 사용자에게는 간결하고 명확한 API를 제공하는 것이 중요합니다. 또한, 다양한 하드웨어 지원 시나리오에 대한 상세한 문서화는 사용자의 이해를 돕고 라이브러리 채택률을 높입니다.
References
- torch.compile - While not directly used in this PR, FlashInfer's JIT compilation strategy shares conceptual similarities with PyTorch's compile feature in aiming to optimize Python code execution via underlying optimized kernels.
- NVIDIA Ampere Architecture Whitepaper - Provides background on the SM120 architecture features relevant to this optimization.
- FlashInfer MLA API Documentation - The public API surface for Multi-Head Attention operations in FlashInfer.
- FlashInfer JIT Compilation - Explains FlashInfer's Just-In-Time compilation mechanism for generating optimized kernels.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-ampere-architecture-whitepaper.pdf
- https://github.com/flashinfer-ai/flashinfer/blob/main/python/flashinfer/mla/__init__.py
- https://github.com/flashinfer-ai/flashinfer/blob/main/python/flashinfer/jit.py
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [flashinfer] FlashInfer Unified MoE API: NVFP4 백엔드 통합 및 자동 튜닝 최적화
- [flashinfer] FlashInfer FP8 KV-Cache Prefill 성능 최적화: Repacking 기법을 통한 오버헤드 제거
- [flashinfer] FlashInfer의 DeepSeek V4 Sparse MLA 최적화 분석
- [flashinfer] FlashInfer, 동적 토큰 페이지 커널 도입으로 TRTLLM-GEN GQA 성능 최적화
- [flashinfer] FlashInfer의 Per-token NVFP4 Quantization 커널 최적화 분석
PR Analysis 의 다른글
- 이전글 [sglang] SGLang NPU 성능 최적화: Disaggregation 모드 개선 분석
- 현재글 : [flashinfer] FlashInfer, SM120 GPU를 위한 희소 MLA 커널 추가로 LLM 추론 속도 향상
- 다음글 [sglang] SGLang: DeepSeek-R1 FP8 GEMM 성능 회귀 문제 해결 및 최적화
댓글