본문으로 건너뛰기

[flashinfer] FlashInfer, 동적 토큰 페이지 커널 도입으로 TRTLLM-GEN GQA 성능 최적화

PR 링크: flashinfer-ai/flashinfer#3259 상태: Merged | 변경: +0 / -0

들어가며

최근 대규모 언어 모델(LLM)의 발전 속도는 눈부십니다. 이러한 모델들의 추론 성능을 극대화하기 위한 노력은 끊임없이 이루어지고 있으며, 그중 하나가 바로 메모리 접근 패턴을 최적화하는 것입니다. NVIDIA의 TensorRT-LLM(TRTLLM)은 LLM 추론을 위한 고성능 라이브러리로, FlashInfer는 TRTLLM의 핵심 컴포넌트인 Multi-Head Attention(MHA) 커널을 최적화하는 데 기여하고 있습니다. 이번 PR은 FlashInfer가 TRTLLM-GEN GQA(Grouped-Query Attention) 커널에 '동적 토큰 페이지(dynamic tokens-per-page)' 기능을 도입하여, 특히 페이지 크기가 128 이상일 때 추론 성능을 향상시키는 것을 목표로 합니다.

이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 이러한 변경이 왜 성능 향상으로 이어지는지, 그리고 어떤 기술적 교훈을 얻을 수 있는지 살펴보겠습니다.

코드 변경 분석

이번 PR의 핵심은 flashinfer/artifacts.py 파일의 아티팩트 경로 및 체크섬 업데이트와 include/flashinfer/trtllm/fmha/fmhaKernels.cuh 파일의 커널 로직 수정입니다.

1. 아티팩트 경로 및 체크섬 업데이트 (flashinfer/artifacts.py)

가장 먼저 눈에 띄는 변경은 TRTLLM_GEN_FMHA의 아티팩트 경로와 체크섬 해시가 업데이트된 것입니다. 이는 새로운 CUDA 커널(cubins)이 빌드되어 기존 경로에서 사용할 수 없게 되었음을 의미합니다. 새로운 커널은 성능 개선이나 새로운 기능 지원을 포함할 수 있습니다.

Before:

-    TRTLLM_GEN_FMHA: str = "1d876ee612888821b168c25ffa75a9dcbb963aaa/fmha/trtllm-gen/"
+    TRTLLM_GEN_FMHA: str = "5d4df6c2647e860992d1cc57ced05204b55f3787/fmha/trtllm-gen/"

Before:

-        "1abeea012a8779c6df5b84332fad43c6cfc3b257fe5ab883c8ea501464010d16"
+        "681e69c9c0215b4780eaf92f897c2dc94285a9143b90f765085eba83af5afa5b"

이러한 변경은 최신 성능 최적화가 적용된 커널을 사용하기 위한 필수적인 과정입니다.

2. 동적 토큰 페이지 커널 도입 (include/flashinfer/trtllm/fmha/fmhaKernels.cuh)

이 PR의 핵심적인 기능 개선은 fmhaKernels.cuh 파일에 구현되었습니다. 특히 TllmGenFmhaKernel 클래스에 동적 토큰 페이지 기능을 위한 로직이 추가되었습니다.

새로운 상수 정의:

+  static constexpr int kDynamicNumTokensPerPageThreshold = 128;
+  static constexpr int kDynamicNumTokensPerPageKernelKey = 128;

kDynamicNumTokensPerPageThreshold는 동적 토큰 페이지 커널을 활성화하기 위한 최소 페이지 크기(128)를 정의합니다. kDynamicNumTokensPerPageKernelKey는 해당 커널에서 사용할 기본 페이지 크기입니다.

numTokensPerPage 검증 로직 수정:

기존에는 numTokensPerPage가 반드시 2의 거듭제곱이어야 했지만, 이제는 0 (사용하지 않음) 또는 2의 거듭제곱이면 허용됩니다.

Before:

-    // The numTokensPerPage must be power of 2.
-    FLASHINFER_CHECK((numTokensPerPage & (numTokensPerPage - 1)) == 0,
-                     "The numTokensPerPage must be power of 2.");
+    // The numTokensPerPage must be 0 (unused for non-paged kernels) or power of 2.
+    FLASHINFER_CHECK(numTokensPerPage == 0 || ((numTokensPerPage & (numTokensPerPage - 1)) == 0),
+                     "The numTokensPerPage must be 0 or power of 2.");

TllmGenFmhaKernel::selectKernel 로직 변경:

selectKernel 함수는 입력 파라미터에 따라 최적의 커널을 선택하는 역할을 합니다. 이 함수 내에서 selectNumTokensPerPage 함수가 호출되어 numTokensPerPagemDynamicNumTokensPerPage 플래그를 설정합니다.

새로운 useDynamicNumTokensPerPage 함수:

+  inline bool useDynamicNumTokensPerPage(RunnerParams const& params) const {
+    return isPagedKv(params.mQkvLayout) && !params.mSparseMla && params.mNumHeadsQPerKv > 1 &&
+           params.mHeadDimQk == params.mHeadDimV &&
+           params.mNumTokensPerPage >= kDynamicNumTokensPerPageThreshold;
+  }

이 함수는 페이징된 KV 캐시 레이아웃을 사용하고, 희소 MLA가 아니며, 헤드 차원이 동일하고, 페이지 크기가 임계값 이상일 때 동적 토큰 페이지 커널을 사용하도록 결정합니다.

새로운 selectNumTokensPerPage 함수:

+  void selectNumTokensPerPage(RunnerParams const& params,
+                              SelectKernelParams& selectKernelParams) const {
+    selectKernelParams.mDynamicNumTokensPerPage = false;
+    if (params.mSparseMla) {
+      // SparseMla kernels use a fixed numTokensPerPage = 1.
+      selectKernelParams.mNumTokensPerPage = 1;
+    } else if (!isPagedKv(params.mQkvLayout)) {
+      // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels.
+      selectKernelParams.mNumTokensPerPage = 0;
+    } else if (useDynamicNumTokensPerPage(params)) {
+      FLASHINFER_CHECK((params.mNumTokensPerPage & (params.mNumTokensPerPage - 1)) == 0,
+                       "Dynamic numTokensPerPage requires a power-of-2 page size, got %d.",
+                       params.mNumTokensPerPage);
+      selectKernelParams.mDynamicNumTokensPerPage = true;
+      selectKernelParams.mNumTokensPerPage = kDynamicNumTokensPerPageKernelKey;
+    } else {
+      selectKernelParams.mNumTokensPerPage = params.mNumTokensPerPage;
+    }
+  }

이 함수는 위에서 정의된 useDynamicNumTokensPerPage 로직을 기반으로 selectKernelParamsmNumTokensPerPagemDynamicNumTokensPerPage를 설정합니다. 동적 토큰 페이지 커널이 선택되면 mDynamicNumTokensPerPagetrue로 설정되고, mNumTokensPerPagekDynamicNumTokensPerPageKernelKey 값으로 고정됩니다.

MLA 생성 커널 로직 확장:

isMlaGenKernel 함수가 확장되어 더 많은 헤드/차원 조합을 MLA 생성 커널로 인식하도록 변경되었습니다.

Before:

-  return params.mHeadDimQk == 576 && params.mHeadDimV == 512;
+  return (params.mHeadDimQk == 576 && params.mHeadDimV == 512) ||
+         (params.mHeadDimQk == 320 && params.mHeadDimV == 256);

슬라이딩 윈도우 및 청크드 어텐션 지원 제한:

리뷰 과정에서 발견된 중요한 변경 사항으로, MLA 생성 커널에서 슬라이딩 윈도우 및 청크드 어텐션 지원이 제거되었습니다. 이는 MLA 커널이 Dense 마스크를 가정하고 동작하기 때문입니다.

+      FLASHINFER_CHECK(
+          params.mMaxSeqLenKv <= params.mAttentionWindowSize &&
+              params.mChunkedAttentionSize == INT_MAX,
+          "TRTLLM-GEN MLA generation does not support sliding-window or chunked attention.");

FLASHINFER_CHECK는 MLA 생성 커널 사용 시 슬라이딩 윈도우나 청크드 어텐션이 활성화되어 있으면 오류를 발생시켜, 잘못된 결과가 반환되는 것을 방지합니다. 이는 'fail-loud' 원칙을 따르는 좋은 설계입니다.

3. 테스트 및 기타 변경사항

  • tests/attention/test_attention_sink_blackwell.py: FP16 Blackwell 어텐션 싱크 컨텍스트 허용 오차를 완화하여 관찰된 정밀도 노이즈를 맞춥니다. 이는 하드웨어의 미세한 차이를 수용하여 테스트 안정성을 높입니다.
  • tests/attention/test_trtllm_gen_attention.py: 동적 토큰 페이지 테스트를 위한 새로운 테스트 케이스가 추가되었습니다. 이는 새로운 기능의 정확성을 검증하는 데 필수적입니다.
  • include/flashinfer/trtllm/fmha/kernelParams.h: tmaKSlidingWindowKvPool_ptrSparseMlaTopKLens와 같은 새로운 멤버 변수가 추가되었습니다. 이는 새로운 커널 기능(예: 희소 MLA 커널)을 지원하기 위한 것입니다.

왜 이게 좋은가?

1. 동적 토큰 페이지 커널의 이점

이 PR의 핵심은 '동적 토큰 페이지' 커널의 도입입니다. LLM 추론 시, 특히 KV 캐시 관리는 성능에 큰 영향을 미칩니다. KV 캐시는 이전 토큰들의 Key와 Value 벡터를 저장하는데, 이 캐시를 효율적으로 관리하는 것이 중요합니다.

기존의 고정된 페이지 크기 방식은 다음과 같은 비효율을 야기할 수 있습니다:

  • 메모리 낭비: 시퀀스 길이가 페이지 크기보다 훨씬 짧을 경우, 페이지의 일부만 사용되고 나머지는 낭비됩니다.
  • 불필요한 오버헤드: 페이지 크기가 너무 작으면 페이지 테이블 관리에 오버헤드가 발생할 수 있습니다.

동적 토큰 페이지 커널은 페이지 크기를 동적으로 조정하거나, 더 효율적인 방식으로 페이지를 활용할 수 있게 합니다. 특히 페이지 크기가 128 이상일 때 이 커널을 사용하도록 설정된 것은, 해당 크기 이상에서 페이지 기반 접근 방식이 더 큰 이점을 보이기 때문일 수 있습니다. 이는 GPU 메모리 대역폭 활용도를 높이고, 불필요한 메모리 접근을 줄여 추론 속도를 향상시킬 수 있습니다.

성능 향상:

PR 설명에 따르면, 이 변경은 특히 페이지 크기가 128 이상일 때 GQA 디코드 및 프리필 성능을 개선합니다. 구체적인 성능 수치는 제공되지 않았지만, 이러한 최적화는 일반적으로 수 밀리초에서 수십 밀리초의 응답 시간 단축으로 이어질 수 있으며, 이는 LLM 서비스의 처리량을 크게 향상시킵니다.

2. MLA 생성 커널 지원 확장

isMlaGenKernel 함수의 확장은 더 넓은 범위의 모델 구성에서 MLA(Multi-head Linear Attention) 커널을 활용할 수 있게 합니다. MLA는 기존의 Scaled Dot-Product Attention보다 계산 복잡도가 낮아 성능 향상에 기여할 수 있습니다. 더 많은 헤드 및 차원 조합을 지원함으로써, 더 많은 사용 사례에서 MLA의 이점을 누릴 수 있게 되었습니다.

3. 'Fail-loud' 원칙 준수

MLA 생성 커널에서 슬라이딩 윈도우 및 청크드 어텐션 지원을 명시적으로 제거하고 FLASHINFER_CHECK를 사용한 것은 매우 좋은 설계입니다. 이는 사용자가 의도치 않게 지원되지 않는 기능을 사용하려 할 때, 오류 메시지를 통해 명확하게 인지하고 수정할 수 있도록 합니다. 이는 디버깅 시간을 단축하고, 잠재적인 버그를 사전에 방지하는 데 큰 도움이 됩니다.

4. 테스트 안정성 향상

test_attention_sink_blackwell.py에서 FP16 허용 오차를 조정한 것은 실제 하드웨어 환경에서의 미묘한 정밀도 차이를 수용하기 위한 실용적인 조치입니다. 이는 테스트의 신뢰성을 높이고, 불필요한 테스트 실패를 줄여 개발 프로세스를 원활하게 합니다.

일반적인 교훈

  1. 하드웨어 특성 활용: GPU 아키텍처(예: Blackwell)와 메모리 접근 패턴(페이지 기반 KV 캐시)의 특성을 깊이 이해하고 이를 커널 최적화에 활용하는 것이 중요합니다.
  2. 동적 커널 선택: 다양한 입력 조건(페이지 크기, 헤드 차원 등)에 따라 최적의 커널을 동적으로 선택하는 메커니즘은 LLM 추론 성능을 극대화하는 핵심 요소입니다.
  3. 'Fail-loud' 원칙: 지원되지 않는 기능 조합이나 잘못된 사용 사례에 대해서는 명확한 오류 메시지를 통해 사용자에게 알리는 것이 장기적으로 더 나은 개발 경험과 시스템 안정성을 보장합니다.
  4. 지속적인 아티팩트 관리: 최적화된 커널은 종종 새로운 바이너리(cubins)로 제공됩니다. 이러한 아티팩트의 경로와 무결성(체크섬)을 최신 상태로 유지하는 것은 라이브러리 유지보수의 중요한 부분입니다.

결론

이번 FlashInfer PR은 TRTLLM-GEN GQA 커널에 동적 토큰 페이지 기능을 도입함으로써 LLM 추론 성능을 한 단계 끌어올렸습니다. 또한, MLA 생성 커널의 지원 범위를 넓히고, 오류 처리 방식을 개선하는 등 다방면에 걸친 최적화를 이루었습니다. 이러한 개선은 더 빠르고 효율적인 LLM 추론을 가능하게 하며, AI 모델의 실질적인 활용도를 높이는 데 기여할 것입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글