[flashinfer] FlashInfer, 동적 토큰 페이지 커널 도입으로 TRTLLM-GEN GQA 성능 최적화
PR 링크: flashinfer-ai/flashinfer#3259 상태: Merged | 변경: +0 / -0
들어가며
최근 대규모 언어 모델(LLM)의 발전 속도는 눈부십니다. 이러한 모델들의 추론 성능을 극대화하기 위한 노력은 끊임없이 이루어지고 있으며, 그중 하나가 바로 메모리 접근 패턴을 최적화하는 것입니다. NVIDIA의 TensorRT-LLM(TRTLLM)은 LLM 추론을 위한 고성능 라이브러리로, FlashInfer는 TRTLLM의 핵심 컴포넌트인 Multi-Head Attention(MHA) 커널을 최적화하는 데 기여하고 있습니다. 이번 PR은 FlashInfer가 TRTLLM-GEN GQA(Grouped-Query Attention) 커널에 '동적 토큰 페이지(dynamic tokens-per-page)' 기능을 도입하여, 특히 페이지 크기가 128 이상일 때 추론 성능을 향상시키는 것을 목표로 합니다.
이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 이러한 변경이 왜 성능 향상으로 이어지는지, 그리고 어떤 기술적 교훈을 얻을 수 있는지 살펴보겠습니다.
코드 변경 분석
이번 PR의 핵심은 flashinfer/artifacts.py 파일의 아티팩트 경로 및 체크섬 업데이트와 include/flashinfer/trtllm/fmha/fmhaKernels.cuh 파일의 커널 로직 수정입니다.
1. 아티팩트 경로 및 체크섬 업데이트 (flashinfer/artifacts.py)
가장 먼저 눈에 띄는 변경은 TRTLLM_GEN_FMHA의 아티팩트 경로와 체크섬 해시가 업데이트된 것입니다. 이는 새로운 CUDA 커널(cubins)이 빌드되어 기존 경로에서 사용할 수 없게 되었음을 의미합니다. 새로운 커널은 성능 개선이나 새로운 기능 지원을 포함할 수 있습니다.
Before:
- TRTLLM_GEN_FMHA: str = "1d876ee612888821b168c25ffa75a9dcbb963aaa/fmha/trtllm-gen/"
+ TRTLLM_GEN_FMHA: str = "5d4df6c2647e860992d1cc57ced05204b55f3787/fmha/trtllm-gen/"
Before:
- "1abeea012a8779c6df5b84332fad43c6cfc3b257fe5ab883c8ea501464010d16"
+ "681e69c9c0215b4780eaf92f897c2dc94285a9143b90f765085eba83af5afa5b"
이러한 변경은 최신 성능 최적화가 적용된 커널을 사용하기 위한 필수적인 과정입니다.
2. 동적 토큰 페이지 커널 도입 (include/flashinfer/trtllm/fmha/fmhaKernels.cuh)
이 PR의 핵심적인 기능 개선은 fmhaKernels.cuh 파일에 구현되었습니다. 특히 TllmGenFmhaKernel 클래스에 동적 토큰 페이지 기능을 위한 로직이 추가되었습니다.
새로운 상수 정의:
+ static constexpr int kDynamicNumTokensPerPageThreshold = 128;
+ static constexpr int kDynamicNumTokensPerPageKernelKey = 128;
kDynamicNumTokensPerPageThreshold는 동적 토큰 페이지 커널을 활성화하기 위한 최소 페이지 크기(128)를 정의합니다. kDynamicNumTokensPerPageKernelKey는 해당 커널에서 사용할 기본 페이지 크기입니다.
numTokensPerPage 검증 로직 수정:
기존에는 numTokensPerPage가 반드시 2의 거듭제곱이어야 했지만, 이제는 0 (사용하지 않음) 또는 2의 거듭제곱이면 허용됩니다.
Before:
- // The numTokensPerPage must be power of 2.
- FLASHINFER_CHECK((numTokensPerPage & (numTokensPerPage - 1)) == 0,
- "The numTokensPerPage must be power of 2.");
+ // The numTokensPerPage must be 0 (unused for non-paged kernels) or power of 2.
+ FLASHINFER_CHECK(numTokensPerPage == 0 || ((numTokensPerPage & (numTokensPerPage - 1)) == 0),
+ "The numTokensPerPage must be 0 or power of 2.");
TllmGenFmhaKernel::selectKernel 로직 변경:
selectKernel 함수는 입력 파라미터에 따라 최적의 커널을 선택하는 역할을 합니다. 이 함수 내에서 selectNumTokensPerPage 함수가 호출되어 numTokensPerPage 및 mDynamicNumTokensPerPage 플래그를 설정합니다.
새로운 useDynamicNumTokensPerPage 함수:
+ inline bool useDynamicNumTokensPerPage(RunnerParams const& params) const {
+ return isPagedKv(params.mQkvLayout) && !params.mSparseMla && params.mNumHeadsQPerKv > 1 &&
+ params.mHeadDimQk == params.mHeadDimV &&
+ params.mNumTokensPerPage >= kDynamicNumTokensPerPageThreshold;
+ }
이 함수는 페이징된 KV 캐시 레이아웃을 사용하고, 희소 MLA가 아니며, 헤드 차원이 동일하고, 페이지 크기가 임계값 이상일 때 동적 토큰 페이지 커널을 사용하도록 결정합니다.
새로운 selectNumTokensPerPage 함수:
+ void selectNumTokensPerPage(RunnerParams const& params,
+ SelectKernelParams& selectKernelParams) const {
+ selectKernelParams.mDynamicNumTokensPerPage = false;
+ if (params.mSparseMla) {
+ // SparseMla kernels use a fixed numTokensPerPage = 1.
+ selectKernelParams.mNumTokensPerPage = 1;
+ } else if (!isPagedKv(params.mQkvLayout)) {
+ // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels.
+ selectKernelParams.mNumTokensPerPage = 0;
+ } else if (useDynamicNumTokensPerPage(params)) {
+ FLASHINFER_CHECK((params.mNumTokensPerPage & (params.mNumTokensPerPage - 1)) == 0,
+ "Dynamic numTokensPerPage requires a power-of-2 page size, got %d.",
+ params.mNumTokensPerPage);
+ selectKernelParams.mDynamicNumTokensPerPage = true;
+ selectKernelParams.mNumTokensPerPage = kDynamicNumTokensPerPageKernelKey;
+ } else {
+ selectKernelParams.mNumTokensPerPage = params.mNumTokensPerPage;
+ }
+ }
이 함수는 위에서 정의된 useDynamicNumTokensPerPage 로직을 기반으로 selectKernelParams의 mNumTokensPerPage와 mDynamicNumTokensPerPage를 설정합니다. 동적 토큰 페이지 커널이 선택되면 mDynamicNumTokensPerPage가 true로 설정되고, mNumTokensPerPage는 kDynamicNumTokensPerPageKernelKey 값으로 고정됩니다.
MLA 생성 커널 로직 확장:
isMlaGenKernel 함수가 확장되어 더 많은 헤드/차원 조합을 MLA 생성 커널로 인식하도록 변경되었습니다.
Before:
- return params.mHeadDimQk == 576 && params.mHeadDimV == 512;
+ return (params.mHeadDimQk == 576 && params.mHeadDimV == 512) ||
+ (params.mHeadDimQk == 320 && params.mHeadDimV == 256);
슬라이딩 윈도우 및 청크드 어텐션 지원 제한:
리뷰 과정에서 발견된 중요한 변경 사항으로, MLA 생성 커널에서 슬라이딩 윈도우 및 청크드 어텐션 지원이 제거되었습니다. 이는 MLA 커널이 Dense 마스크를 가정하고 동작하기 때문입니다.
+ FLASHINFER_CHECK(
+ params.mMaxSeqLenKv <= params.mAttentionWindowSize &&
+ params.mChunkedAttentionSize == INT_MAX,
+ "TRTLLM-GEN MLA generation does not support sliding-window or chunked attention.");
이 FLASHINFER_CHECK는 MLA 생성 커널 사용 시 슬라이딩 윈도우나 청크드 어텐션이 활성화되어 있으면 오류를 발생시켜, 잘못된 결과가 반환되는 것을 방지합니다. 이는 'fail-loud' 원칙을 따르는 좋은 설계입니다.
3. 테스트 및 기타 변경사항
tests/attention/test_attention_sink_blackwell.py: FP16 Blackwell 어텐션 싱크 컨텍스트 허용 오차를 완화하여 관찰된 정밀도 노이즈를 맞춥니다. 이는 하드웨어의 미세한 차이를 수용하여 테스트 안정성을 높입니다.tests/attention/test_trtllm_gen_attention.py: 동적 토큰 페이지 테스트를 위한 새로운 테스트 케이스가 추가되었습니다. 이는 새로운 기능의 정확성을 검증하는 데 필수적입니다.include/flashinfer/trtllm/fmha/kernelParams.h:tmaKSlidingWindowKvPool_및ptrSparseMlaTopKLens와 같은 새로운 멤버 변수가 추가되었습니다. 이는 새로운 커널 기능(예: 희소 MLA 커널)을 지원하기 위한 것입니다.
왜 이게 좋은가?
1. 동적 토큰 페이지 커널의 이점
이 PR의 핵심은 '동적 토큰 페이지' 커널의 도입입니다. LLM 추론 시, 특히 KV 캐시 관리는 성능에 큰 영향을 미칩니다. KV 캐시는 이전 토큰들의 Key와 Value 벡터를 저장하는데, 이 캐시를 효율적으로 관리하는 것이 중요합니다.
기존의 고정된 페이지 크기 방식은 다음과 같은 비효율을 야기할 수 있습니다:
- 메모리 낭비: 시퀀스 길이가 페이지 크기보다 훨씬 짧을 경우, 페이지의 일부만 사용되고 나머지는 낭비됩니다.
- 불필요한 오버헤드: 페이지 크기가 너무 작으면 페이지 테이블 관리에 오버헤드가 발생할 수 있습니다.
동적 토큰 페이지 커널은 페이지 크기를 동적으로 조정하거나, 더 효율적인 방식으로 페이지를 활용할 수 있게 합니다. 특히 페이지 크기가 128 이상일 때 이 커널을 사용하도록 설정된 것은, 해당 크기 이상에서 페이지 기반 접근 방식이 더 큰 이점을 보이기 때문일 수 있습니다. 이는 GPU 메모리 대역폭 활용도를 높이고, 불필요한 메모리 접근을 줄여 추론 속도를 향상시킬 수 있습니다.
성능 향상:
PR 설명에 따르면, 이 변경은 특히 페이지 크기가 128 이상일 때 GQA 디코드 및 프리필 성능을 개선합니다. 구체적인 성능 수치는 제공되지 않았지만, 이러한 최적화는 일반적으로 수 밀리초에서 수십 밀리초의 응답 시간 단축으로 이어질 수 있으며, 이는 LLM 서비스의 처리량을 크게 향상시킵니다.
2. MLA 생성 커널 지원 확장
isMlaGenKernel 함수의 확장은 더 넓은 범위의 모델 구성에서 MLA(Multi-head Linear Attention) 커널을 활용할 수 있게 합니다. MLA는 기존의 Scaled Dot-Product Attention보다 계산 복잡도가 낮아 성능 향상에 기여할 수 있습니다. 더 많은 헤드 및 차원 조합을 지원함으로써, 더 많은 사용 사례에서 MLA의 이점을 누릴 수 있게 되었습니다.
3. 'Fail-loud' 원칙 준수
MLA 생성 커널에서 슬라이딩 윈도우 및 청크드 어텐션 지원을 명시적으로 제거하고 FLASHINFER_CHECK를 사용한 것은 매우 좋은 설계입니다. 이는 사용자가 의도치 않게 지원되지 않는 기능을 사용하려 할 때, 오류 메시지를 통해 명확하게 인지하고 수정할 수 있도록 합니다. 이는 디버깅 시간을 단축하고, 잠재적인 버그를 사전에 방지하는 데 큰 도움이 됩니다.
4. 테스트 안정성 향상
test_attention_sink_blackwell.py에서 FP16 허용 오차를 조정한 것은 실제 하드웨어 환경에서의 미묘한 정밀도 차이를 수용하기 위한 실용적인 조치입니다. 이는 테스트의 신뢰성을 높이고, 불필요한 테스트 실패를 줄여 개발 프로세스를 원활하게 합니다.
일반적인 교훈
- 하드웨어 특성 활용: GPU 아키텍처(예: Blackwell)와 메모리 접근 패턴(페이지 기반 KV 캐시)의 특성을 깊이 이해하고 이를 커널 최적화에 활용하는 것이 중요합니다.
- 동적 커널 선택: 다양한 입력 조건(페이지 크기, 헤드 차원 등)에 따라 최적의 커널을 동적으로 선택하는 메커니즘은 LLM 추론 성능을 극대화하는 핵심 요소입니다.
- 'Fail-loud' 원칙: 지원되지 않는 기능 조합이나 잘못된 사용 사례에 대해서는 명확한 오류 메시지를 통해 사용자에게 알리는 것이 장기적으로 더 나은 개발 경험과 시스템 안정성을 보장합니다.
- 지속적인 아티팩트 관리: 최적화된 커널은 종종 새로운 바이너리(cubins)로 제공됩니다. 이러한 아티팩트의 경로와 무결성(체크섬)을 최신 상태로 유지하는 것은 라이브러리 유지보수의 중요한 부분입니다.
결론
이번 FlashInfer PR은 TRTLLM-GEN GQA 커널에 동적 토큰 페이지 기능을 도입함으로써 LLM 추론 성능을 한 단계 끌어올렸습니다. 또한, MLA 생성 커널의 지원 범위를 넓히고, 오류 처리 방식을 개선하는 등 다방면에 걸친 최적화를 이루었습니다. 이러한 개선은 더 빠르고 효율적인 LLM 추론을 가능하게 하며, AI 모델의 실질적인 활용도를 높이는 데 기여할 것입니다.
참고 자료
- https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fmha/fmhaKernels.cuh
- https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
- https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fmha/kernelParams.h
- https://github.com/flashinfer-ai/flashinfer/blob/main/tests/attention/test_attention_sink_blackwell.py
- https://github.com/flashinfer-ai/flashinfer/blob/main/tests/attention/test_trtllm_gen_attention.py
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [flashinfer] FlashInfer, CUDA 그래프 호환성을 높이고 성능을 최적화하다: TRT-LLM FMHA v2 통합 및 불필요한 H2D 제거
- [flashinfer] FlashInfer, FP8 지원으로 장문 컨텍스트 추론 성능을 극적으로 향상시키다
- [flashinfer] FlashInfer, CuTe DSL 기반 FMHA 커널 통합으로 사전 생성(Prefill) 성능 극대화
- [sglang] SGLang: MiniMax-M2.5 MoE 모델을 위한 FP8 FlashInfer TRT-LLM 라우팅 최적화
- [flashinfer] FlashInfer의 Per-token NVFP4 Quantization 커널 최적화 분석
PR Analysis 의 다른글
- 이전글 [cpython] CPython inspect.getattr_static 성능 개선: 일반적인 메타클래스 사례 최적화
- 현재글 : [flashinfer] FlashInfer, 동적 토큰 페이지 커널 도입으로 TRTLLM-GEN GQA 성능 최적화
- 다음글 [sglang] SGLang의 Breakable CUDA Graph 최적화: 배치 사이즈 제한 극복하기
댓글