본문으로 건너뛰기

[flashinfer] FlashInfer, MoE 및 FP8 GEMM 성능 향상을 위한 커널 업데이트

PR 링크: flashinfer-ai/flashinfer#3239 상태: Merged | 변경: +0 / -0

들어가며

대규모 언어 모델(LLM)의 발전과 함께 Mixture-of-Experts (MoE) 아키텍처가 주목받고 있습니다. MoE 모델은 더 적은 연산량으로 더 큰 모델을 구현할 수 있다는 장점이 있지만, 이를 효율적으로 처리하기 위한 고성능 컴퓨팅 라이브러리의 최적화가 필수적입니다. NVIDIA의 FlashInfer는 이러한 요구에 부응하여 GPU 상에서 행렬 곱셈(GEMM) 및 관련 연산을 가속하는 라이브러리입니다. 이번 PR은 FlashInfer의 핵심 컴포넌트 중 하나인 trtllm_batched_gemm_runner.cu 및 관련 커널 파일들을 업데이트하여, 특히 mxfp4 MoE, DeepSeek-V3 MoE, 그리고 mxfp8 GEMM 연산의 성능을 향상시키는 데 중점을 두고 있습니다.

이 PR은 단순히 성능 개선뿐만 아니라, 다양한 GPU 아키텍처 및 데이터 타입에 대한 호환성과 정확성을 높이는 중요한 변경을 포함하고 있습니다. 본 글에서는 이 PR의 주요 코드 변경 사항을 분석하고, 이러한 변경이 왜 성능 향상과 안정성 증대에 기여하는지 상세히 설명하겠습니다.

코드 분석

이번 PR의 변경 사항은 주로 CUDA 커널 로딩 및 선택 로직, 그리고 아티팩트 경로 관리에 집중되어 있습니다. 각 파일을 중심으로 변경 내용을 살펴보겠습니다.

1. csrc/trtllm_batched_gemm_runner.cu

이 파일은 TensorRT-LLM의 배치 행렬 곱셈(Batched GEMM) 실행을 관리하는 클래스(TrtllmGenBatchedGemmRunner)의 구현을 담고 있습니다. 주요 변경 사항은 다음과 같습니다.

GPU SM 버전별 호환성 강화

이전 코드에서는 bmm.getBatchedGemmConfigs()를 통해 얻은 설정들을 단순히 순회하며 mOptions와 일치하는지 확인했습니다. 하지만 이번 업데이트에서는 GPU의 스트리밍 멀티프로세서(SM) 버전에 따른 커널 호환성을 명시적으로 검사하는 로직이 추가되었습니다.

Before:

@@ -91,9 +91,11 @@
   auto const configs = bmm.getBatchedGemmConfigs();
 
   mPassingConfigIndices.clear();
+  auto sm_version = getSMVersion();
 
   for (size_t i = 0; i < bmm.getNumBatchedGemmConfigs(); ++i) {
-    auto const options = configs[i].mOptions;
+    auto const config = configs[i];
+    auto const options = config.mOptions;
     auto const tileSize = mOptions.transposeMmaOutput ? options.mTileN : options.mTileM;
     // When we include low-latency kernels we can set transposeMmaOutput via constructor
     if (options.mDtypeA == mOptions.dtypeA && options.mDtypeB == mOptions.dtypeB &&
@@ -119,7 +121,13 @@
       if ((int64_t)options.mEltwiseActType != (int64_t)mOptions.eltwiseActType) {
         continue;
       }
-
+      // if patchF2fp is enabled, sm100f cubins cannot be used for sm103
+      if (options.mPatchF2fp && sm_version == 103) {
+        if (config.mSm != tg::CudaArch::Sm103a) continue;
+      }
+      if (options.mPatchF2fp && sm_version == 100) {
+        if (config.mSm != tg::CudaArch::Sm100a && config.mSm != tg::CudaArch::Sm100f) continue;
+      }
       if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM) {
         mPassingConfigIndices.push_back(i);
       }

getSMVersion() 함수를 통해 현재 실행 중인 GPU의 SM 버전을 가져옵니다. 이후, options.mPatchF2fp 플래그가 활성화된 경우, 특정 SM 버전(예: SM 103)에서는 특정 커널(예: Sm100f)을 사용하지 않도록 필터링합니다. 이는 patchF2fp 기능이 특정 아키텍처에 최적화되어 있거나, 다른 아키텍처에서는 호환성 문제가 발생할 수 있기 때문입니다. 이러한 조건부 필터링은 런타임 시점에 더 적합하고 안정적인 커널을 선택하게 하여 오류를 방지하고 성능을 보장합니다.

2. csrc/trtllm_gemm_runner.cu

이 파일은 TensorRT-LLM의 일반 GEMM 연산을 위한 커널 선택 로직을 포함합니다. FP8 데이터 타입을 사용하는 커널 선택 과정에서 변경이 있었습니다.

Before:

@@ -44,21 +44,21 @@
 struct TrtllmGenGemmRunnerOptions {
 int64_t select_kernel_fp8(int32_t M, int32_t N, int32_t K, 
                           const gemm::gemm::GemmInterface& interface) {
   static constexpr const char* KERNEL_NAME_HIGH_N_K_RATIO = 
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128u2_s6_et64x8_m64x8x32_c1x1x1_16dp256b_rM_TN_"
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128u2_s6_et64x8_m64x8x32_c1x1x1_rM_TN_"
       "transOut_"
-      "noShflA_dsFp8_schPd2x2x1x3_sm100f";
+      "noShfl_dsFp8_schPd2x2x1x3_sm100f";
 
   static constexpr const char* KERNEL_NAME_LOW_N_K_RATIO = 
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_c1x1x1_16dp256b_rM_TN_"
-      "transOut_noShflA_dsFp8_schedS_sm100f";
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_c1x1x1_rM_TN_"
+      "transOut_noShfl_dsFp8_schedS_sm100f";
 
   static constexpr const char* KERNEL_NAME_LARGE_N = 
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_c1x1x1_16dp256b_rM_TN_"
-      "transOut_noShflA_dsFp8_schPd2x2x1x3_sm100f";
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_c1x1x1_rM_TN_"
+      "transOut_noShfl_dsFp8_schPd2x2x1x3_sm100f";
 
   static constexpr const char* KERNEL_NAME_DEFAULT = 
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128u2_s6_et64x16_m64x16x32_c1x1x1_16dp256b_rM_TN_"
-      "transOut_noShflA_dsFp8_schedS_sm100f";
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128u2_s6_et64x16_m64x16x32_c1x1x1_rM_TN_"
+      "transOut_noShfl_dsFp8_schedS_sm100f";
 
   double const n_k_ratio = static_cast<double>(N) / static_cast<double>(K);
 

주요 변경점은 커널 이름 문자열에서 _16dp256b 부분이 제거되고, _noShflA_noShfl로 변경된 것입니다. 이는 사용되는 커널의 세부 옵션이 변경되었음을 나타냅니다. 특히 _16dp256b는 특정 메모리 접근 패턴이나 데이터 배치와 관련이 있을 수 있으며, 이를 제거함으로써 더 일반적이거나 다른 최적화된 패턴을 사용하는 커널로 대체되었을 가능성이 있습니다. _noShflA에서 _noShfl로의 변경은 스레드 간 셔플(shuffle) 연산 사용 여부에 대한 명칭 규칙 변경일 수 있으며, 이는 커널 구현의 미세 조정과 관련이 있습니다. 이러한 변경은 FP8 연산의 효율성을 높이기 위한 커널의 재구성 또는 업데이트를 반영합니다.

3. csrc/trtllm_low_latency_gemm_runner.cu

이 파일은 저지연 GEMM 커널 선택 로직을 담당합니다. 여기에서도 커널 이름 문자열에 대한 업데이트가 이루어졌습니다.

Before:

@@ -63,28 +63,28 @@
  */
 int64_t select_kernel(int32_t m, int32_t n, int32_t k, const gemm::gemm::GemmInterface& interface) {
   static constexpr const char* KERNEL_MMAN_8_TILEK_128 = 
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128_s7_et128x8_m128x8x32_c1x1x1_16dp256b_rM_BN_"
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128_s7_et128x8_m128x8x32_c1x1x1_rM_BN_"
       "transOut_schedS_sm100f";
   static constexpr const char* KERNEL_MMAN_8_TILEK_256 = 
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x256_s4_et128x8_m128x8x32_c1x1x1_16dp256b_rM_BN_"
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x256_s4_et128x8_m128x8x32_c1x1x1_rM_BN_"
       "transOut_schedS_sm100f";
   static constexpr const char* KERNEL_MMAN_16_TILEK_128 = 
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x128_s7_et128x32_m128x64x32_c1x1x1_16dp256b_rM_BN_"
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x128_s7_et128x32_m128x64x32_c1x1x1_rM_BN_"
       "transOut_schedS_sm100f";
   static constexpr const char* KERNEL_MMAN_16_TILEK_256 = 
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x256_s3_et128x32_m128x64x32_c1x1x1_16dp256b_rM_BN_"
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x256_s3_et128x32_m128x64x32_c1x1x1_rM_BN_"
       "transOut_schedS_sm100f";
   static constexpr const char* KERNEL_MMAN_32_TILEK_128 = 
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128_s9_et128x32_m128x32x32_c1x1x1_16dp256b_rM_BN_"
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128_s9_et128x32_m128x32x32_c1x1x1_rM_BN_"
       "transOut_schedS_sm100f";
   static constexpr const char* KERNEL_MMAN_32_TILEK_256 = 
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x256_s5_et128x32_m128x32x32_c1x1x1_16dp256b_rM_BN_"
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x256_s5_et128x32_m128x32x32_c1x1x1_rM_BN_"
       "transOut_schedS_sm100f";
   static constexpr const char* KERNEL_MMAN_64_TILEK_128 = 
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128_s7_et128x16_m128x16x32_c1x1x1_16dp256b_rM_BN_"
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128_s7_et128x16_m128x16x32_c1x1x1_rM_BN_"
       "transOut_schedS_sm100f";
   static constexpr const char* KERNEL_MMAN_64_TILEK_256 = 
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x256_s5_et128x16_m128x16x32_c1x1x1_16dp256b_rM_BN_"
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x256_s5_et128x16_m128x16x32_c1x1x1_rM_BN_"
       "transOut_schedS_sm100f";
 
   std::string kernel_name;

마찬가지로, _16dp256b 접두사가 제거되고 _rM_BN_으로 변경되었습니다. 이는 저지연 GEMM 커널의 내부 구현이 업데이트되었음을 시사합니다. _rM_BN_은 아마도 행렬 곱셈의 결과(Result Matrix)와 배치(Batch) 연산에 대한 새로운 최적화 기법을 나타낼 수 있습니다. 이러한 변경은 특정 연산 패턴에서 더 낮은 지연 시간을 달성하기 위한 커널 튜닝의 결과로 보입니다.

4. flashinfer/artifacts.py

이 파일은 FlashInfer가 사용하는 사전 컴파일된 CUDA 커널(cubin) 파일들의 경로와 체크섬(checksum)을 관리합니다. 이번 PR에서는 이러한 아티팩트들의 경로와 해시값이 업데이트되었습니다.

Before:

@@ -137,10 +137,10 @@
 
     TRTLLM_GEN_FMHA: str = "1d876ee612888821b168c25ffa75a9dcbb963aaa/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
-        "3d9dd08b1691e63e298a7b862d74fd7af3daf594/batched_gemm-4fc8a68-6743435/"
+        "c21ddd11585c1eea5764927465d0be15dd957e45/batched_gemm-91e0ba0-da44fdf/"
     )
     TRTLLM_GEN_GEMM: str = (
-        "31e75d429ff3f710de1251afdd148185f53da44d/gemm-4daf11e-1fddea2/"
+        "10f64528a1172dae8e29601a3b99ab9dc78d37be/gemm-91e0ba0-2710384/"
     )
     CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
     # For DEEPGEMM, we also need to update KernelMap.KERNEL_MAP_HASH in flashinfer/deep_gemm.py
@@ -160,11 +160,11 @@
         "1abeea012a8779c6df5b84332fad43c6cfc3b257fe5ab883c8ea501464010d16"
     )
     TRTLLM_GEN_BMM: str = (
-        "44174e2a08bb427088f5b5443bf0108bb6fb6cb0812ff6018f6418b3d2273824"
+        "4a3ed9c3dc6547ea3eed01ebda75b0e4322f6c01fc40cd2a4978e4deaba2732a"
     )
     DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
     TRTLLM_GEN_GEMM: str = (
-        "64b7114a429ea153528dd4d4b0299363d7320964789eb5efaefec66f301523c7"
+        "f97f90f9ce1dab73eb3d7c90fca4bbd52687642dd87a79dd10b77d7802b25c33"
     )
     # SHA256 of the checksums.txt manifest file per cpu-arch/sm-arch,
     # NOT hashes of individual kernel .so files.

TRTLLM_GEN_BMMTRTLLM_GEN_GEMM에 해당하는 경로와 체크섬 해시값이 변경되었습니다. 이는 컴파일된 커널 바이너리 파일 자체가 업데이트되었음을 의미합니다. 새로운 해시값은 해당 커널들이 성능 개선 또는 버그 수정을 포함하고 있음을 나타냅니다. flashinfer/artifacts.py는 이러한 바이너리들을 다운로드하고 검증하는 데 사용되므로, 이 파일의 업데이트는 라이브러리가 최신 성능 최적화 커널을 사용하도록 보장하는 데 필수적입니다.

왜 이게 좋은가?

이 PR은 여러 측면에서 좋은 최적화 및 개선을 포함하고 있습니다.

  1. 성능 향상: MoE 및 FP8 GEMM 커널의 업데이트는 직접적으로 관련 연산의 속도 향상을 가져옵니다. 특히 MoE 모델은 LLM에서 점점 더 중요해지고 있으며, FP8은 모델 양자화 및 추론 속도 향상을 위한 핵심 기술입니다. 이러한 연산의 성능 개선은 LLM 추론의 전체적인 효율성을 높이는 데 크게 기여합니다.
  2. 호환성 및 안정성 증대: trtllm_batched_gemm_runner.cu 파일에서 SM 버전별 호환성을 강화한 것은 다양한 NVIDIA GPU 아키텍처에서 라이브러리가 안정적으로 작동하도록 보장합니다. 특정 아키텍처에 맞지 않는 커널을 사용하려 할 때 발생하는 런타임 오류를 방지하고, 각 아키텍처의 특성을 최대한 활용할 수 있도록 합니다.
  3. 최신 커널 활용: flashinfer/artifacts.py의 업데이트는 라이브러리가 최신 버전의 TensorRT-LLM 커널을 사용하도록 합니다. TensorRT-LLM은 NVIDIA의 최신 GPU 기술과 최적화 기법을 반영하여 지속적으로 업데이트되므로, 이를 FlashInfer에서 활용하는 것은 성능 및 기능 개선의 핵심입니다.
  4. 코드 명확성 및 유지보수성: 커널 이름 문자열의 변경은 때로는 가독성을 해칠 수도 있지만, 여기서는 특정 최적화 기법(예: _16dp256b 제거, _noShflA -> _noShfl)을 반영하는 것으로 보입니다. 이는 커널의 동작 방식을 더 명확하게 이해하고 향후 유지보수하는 데 도움을 줄 수 있습니다.

일반적 교훈:

  • 고성능 컴퓨팅 라이브러리에서는 특정 하드웨어 아키텍처(GPU SM 버전)에 대한 호환성 검증이 매우 중요합니다. 런타임 시점에 적절한 커널을 선택하거나, 호환되지 않는 커널 사용을 방지하는 로직은 안정성과 성능을 모두 보장합니다.
  • FP8과 같은 저정밀도 연산은 성능 향상의 핵심 동력이지만, 이를 지원하는 커널의 최적화는 복잡하고 지속적인 노력이 필요합니다. 커널 이름 문자열의 변경은 이러한 최적화 과정의 일부일 수 있습니다.
  • 사전 컴파일된 바이너리(cubin)를 사용하는 경우, 해당 바이너리의 경로와 무결성 검증(checksum)을 관리하는 메커니즘이 필수적입니다. 이러한 아티팩트 관리 파일의 업데이트는 라이브러리가 항상 최적의 성능을 내는 버전을 사용하도록 보장합니다.

결론

이번 PR은 FlashInfer가 MoE 및 FP8 GEMM 연산의 성능을 최적화하고, 다양한 GPU 환경에서의 호환성을 개선하는 중요한 업데이트입니다. 커널 로직의 미세 조정과 최신 아티팩트 활용을 통해 라이브러리의 전반적인 효율성과 안정성을 향상시켰습니다. 이는 LLM 추론 성능을 극대화하려는 FlashInfer의 지속적인 노력의 일환으로 볼 수 있습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글