[flashinfer] FlashInfer, CuTe DSL 기반 FMHA 커널 통합으로 사전 생성(Prefill) 성능 극대화
PR 링크: flashinfer-ai/flashinfer#3039 상태: Merged | 변경: +None / -None
들어가며
최근 대규모 언어 모델(LLM)의 발전은 이전과는 비교할 수 없는 규모의 연산을 요구하고 있습니다. 특히, LLM의 추론 과정에서 발생하는 사전 생성(Prefill) 단계는 전체 응답 생성 시간의 상당 부분을 차지하며, 이 단계의 효율성은 모델의 실질적인 성능에 직접적인 영향을 미칩니다. NVIDIA의 최신 GPU 아키텍처인 Blackwell(SM100/SM103/SM110)은 이러한 연산 부담을 줄이기 위한 새로운 기능들을 포함하고 있습니다.
이번 PR은 FlashInfer 라이브러리에 NVIDIA의 CuTe DSL(Domain Specific Language)을 사용하여 사전 컴파일된 FMHA(Fused Multi-Head Attention) 커널을 통합하는 중요한 변경사항을 담고 있습니다. 이 통합은 특히 Blackwell 아키텍처에서 사전 생성(Prefill) 단계의 성능을 크게 향상시키는 것을 목표로 합니다. 기존의 Just-In-Time (JIT) 컴파일 방식 대신, 미리 컴파일된 커널(cubin)을 런타임에 로드함으로써 컴파일 오버헤드를 제거하고, TensorRT-LLM의 trtllm_ragged_attention_deepseek() API를 통해 이 새로운 커널을 활용할 수 있도록 합니다.
이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, CuTe DSL 커널 통합이 왜 성능 향상으로 이어지는지, 그리고 이 최적화가 가지는 기술적 의미는 무엇인지 살펴보겠습니다.
코드 분석
이번 PR의 핵심 변경 사항은 주로 세 파일에 집중되어 있습니다:
flashinfer/attention_dsl/cute_dsl/fmha.py: CuTe DSL FMHA 커널 로딩, 변형 선택, ragged prefill 진입점 관련 로직을 담당합니다.flashinfer/artifacts.py: CuTe DSL FMHA 커널에 대한 아티팩트 경로 및 체크섬 정보를 관리합니다. 이는 커널 파일의 무결성을 보장하는 데 중요합니다.flashinfer/prefill.py: TensorRT-LLM의trtllm_ragged_attention_deepseek()API에 CuTe DSL 백엔드를 통합하는 로직을 포함합니다.
benchmarks/routines/attention.py 및 benchmarks/routines/flashinfer_benchmark_utils.py
이 파일들은 새로운 cute-dsl 백엔드를 지원하도록 수정되었습니다. 특히, testBatchPrefillWithRaggedKVCacheWrapper 함수에서는 cute-dsl 백엔드를 사용할 때 입력 텐서(Q, K, V)와 출력 텐서(O)에 대한 front_pad 로직이 추가되었습니다. 이는 CuTe DSL의 ragged varlen 커널이 데이터 시작 부분에 유효한 GPU 메모리가 필요하기 때문입니다. 또한, flashinfer_benchmark_utils.py의 SUPPORTED_BACKENDS 딕셔너리에 CUDA 버전 10.0 및 10.3에 대해 cute-dsl 백엔드가 추가되어, 해당 환경에서 벤치마킹이 가능하도록 설정되었습니다.
Before:
q = torch.randn(
- cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype
+ cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype
)
if args.verbose >= 2:
print(f"[VVERBOSE] {q.shape = }")
- k = torch.randn(
- cumsum_s_kv, num_kv_heads, head_dim_qk, device=device, dtype=kv_init_dtype
+ k_full = torch.randn(
+ front_pad_kv + cumsum_s_kv,
+ num_kv_heads,
+ head_dim_qk,
+ device=device,
+ dtype=kv_init_dtype,
)
- v = torch.randn(
- cumsum_s_kv, num_kv_heads, head_dim_vo, device=device, dtype=kv_init_dtype
+ v_full = torch.randn(
+ front_pad_kv + cumsum_s_kv,
+ num_kv_heads,
+ head_dim_vo,
+ device=device,
+ dtype=kv_init_dtype,
)
+ k = k_full[front_pad_kv:]
+ v = v_full[front_pad_kv:]
After:
+ # Front-padding for cute-dsl varlen kernel: the persistent varlen kernel
+ # applies a negative pointer offset (-max_s * H * D), so there must be
+ # valid GPU memory before the data start.
+ front_pad_q = s_qo if "cute-dsl" in backends else 0
+ front_pad_kv = s_kv if "cute-dsl" in backends else 0
+
+ q_full = torch.randn(
+ front_pad_q + cumsum_s_qo,
+ num_qo_heads,
+ head_dim_qk,
+ device=device,
+ dtype=q_init_dtype,
+ )
+ q = q_full[front_pad_q:]
+ k_full = torch.randn(
+ front_pad_kv + cumsum_s_kv,
+ num_kv_heads,
+ head_dim_qk,
+ device=device,
+ dtype=kv_init_dtype,
+ )
+ v_full = torch.randn(
+ front_pad_kv + cumsum_s_kv,
+ num_kv_heads,
+ head_dim_vo,
+ device=device,
+ dtype=kv_init_dtype,
+ )
+ k = k_full[front_pad_kv:]
+ v = v_full[front_pad_kv:]
또한, trtllm_out 텐서 할당 시에도 cute-dsl 백엔드를 위해 out_pad 로직이 추가되었습니다. 이는 출력 텐서 역시 입력과 유사한 패딩 요구사항을 가질 수 있음을 시사합니다.
Before:
- if "trtllm-native" in backends:
- trtllm_out = torch.empty(
- q.shape[0],
+ if "trtllm-native" in backends or "cute-dsl" in backends:
+ # cute-dsl varlen kernel uses negative pointer offsets on output,
+ # so front-pad like Q/K/V.
+ out_pad = front_pad_q if "cute-dsl" in backends else 0
+ trtllm_out_full = torch.empty(
+ out_pad + q.shape[0],
After:
q.shape[1],
v.shape[2],
device=q.device,
dtype=out_dtype,
)
+ trtllm_out = trtllm_out_full[out_pad:]
flashinfer/artifacts.py
이 파일은 새로운 CuTe DSL FMHA 커널에 대한 저장소 경로(ArtifactPath.DSL_FMHA)와 지원하는 아키텍처(ArtifactPath.DSL_FMHA_ARCHS)를 정의합니다. 또한, 각 CPU 아키텍처(x86_64, aarch64) 및 SM 아키텍처(sm_100a, sm_103a, sm_110a)별로 checksums.txt 파일의 SHA256 해시값을 CheckSumHash.DSL_FMHA_CHECKSUMS에 추가했습니다. 이는 get_subdir_file_list 함수에서 해당 커널 파일들을 다운로드하고 무결성을 검증하는 데 사용됩니다. _get_host_cpu_arch 함수는 현재 실행 중인 호스트의 CPU 아키텍처를 감지하여 올바른 아티팩트 경로를 선택하도록 돕습니다.
Before: (일부 발췌)
class ArtifactPath:
CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
+
+
class CheckSumHash:
TRTLLM_GEN_GEMM: str = \
"64b7114a429ea153528dd4d4b0299363d7320964789eb5efaefec66f301523c7"
After: (일부 발췌)
class ArtifactPath:
CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
+ DSL_FMHA: str = "c770c91cb0d991b7828fc85d2253a62f0d356b6c/fmha/cute-dsl/"
+ DSL_FMHA_ARCHS: tuple[str, ...] = ("sm_100a", "sm_103a", "sm_110a")
+
+
class CheckSumHash:
TRTLLM_GEN_GEMM: str = \
"64b7114a429ea153528dd4d4b0299363d7320964789eb5efaefec66f301523c7"
+ # SHA256 of the checksums.txt manifest file per cpu-arch/sm-arch,
+ # NOT hashes of individual kernel .so files.
+ DSL_FMHA_CHECKSUMS: dict[str, dict[str, str]] = {
+ "x86_64": {
+ "sm_100a": "9533536698cdc256d897fffb3114de317076654ff8630ff283d850cc3dc96d86",
+ "sm_103a": "927e1954f1d45b0ee876f139084e4facdfcc87e86f4d30cb92d5c33698d4c2d6",
+ "sm_110a": "277b1dceaab2081e3def37cf997280a3f2c3ac515d22b80be141253c0278b8b5",
+ },
+ "aarch64": {
+ "sm_100a": "b48ed0bcc9bad4afd33e0784c8c9eb9e13e782afe197816b1d0747b11759493e",
+ "sm_103a": "bace619a560f3ce52ad6ba105fffb8ea8629fe57885a90892c9e15a7122467e1",
+ "sm_110a": "d8369bcfa443bfd791cd014e3b030d378f00a975db8278eebd5b2fb529e3257d",
+ },
+ }
+ map_checksums: dict[str, str] = {
+ safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA,
+ safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "checksums.txt"): TRTLLM_GEN_BMM,
+ safe_urljoin(ArtifactPath.DEEPGEMM, "checksums.txt"): DEEPGEMM,
+ safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "checksums.txt"): TRTLLM_GEN_GEMM,
+ **{
+ safe_urljoin(
+ ArtifactPath.DSL_FMHA, f"{cpu_arch}/{sm_arch}/checksums.txt"
+ ): sha
+ for cpu_arch, sm_checksums in DSL_FMHA_CHECKSUMS.items()
+ for sm_arch, sha in sm_checksums.items()
+ },
+ }
flashinfer/prefill.py
이 파일은 trtllm_ragged_attention_deepseek() 함수에 backend="cute-dsl" 옵션을 추가하여 CuTe DSL FMHA 커널을 호출할 수 있도록 합니다. 리뷰어의 피드백에 따라, 기존 trtllm-native 대신 trtllm-gen으로 명칭을 통일하는 수정도 이루어졌습니다. 또한, flashinfer/attention/_core.py (이전 flashinfer/attention.py) 파일에서 flashinfer.prefill 모듈 임포트 경로가 flashinfer.prefill에서 ../prefill로 변경되었습니다. 이는 내부 모듈 구조 변경에 따른 것입니다.
Before: (일부 발췌)
from .api_logging import flashinfer_api
from .jit import gen_batch_attention_module
from .utils import (
- MaskMode,
- PosEncodingMode,
- TensorLayout,
- _unpack_paged_kv_cache,
- determine_attention_backend,
)
-from .prefill import BatchPrefillWithPagedKVCacheWrapper
-from .jit.attention.variants import attention_sink_decl
-from .jit.utils import filename_safe_dtype_map
+from ..api_logging import flashinfer_api
+from ..jit import gen_batch_attention_module
+from ..utils import (
+ MaskMode,
+ PosEncodingMode,
+ TensorLayout,
+ _unpack_paged_kv_cache,
+ determine_attention_backend,
)
+from ..prefill import BatchPrefillWithPagedKVCacheWrapper
+from ..jit.attention.variants import attention_sink_decl
+from ..jit.utils import filename_safe_dtype_map
After: (일부 발췌)
return backend_wrappers[backend].run_return_lse(q, k, v)[0]
+ elif backend == "cute-dsl":
+ _q_scale = q_scale if q_scale is not None else 1.0
+ _k_scale = k_scale if k_scale is not None else 1.0
+ _v_scale = v_scale if v_scale is not None else 1.0
+ return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+ query=q,
+ key=k,
+ value=v,
+ workspace_buffer=workspace_buffer,
+ seq_lens=actual_seq_lens_kv_device,
+ max_q_len=s_qo,
+ max_kv_len=s_kv,
+ bmm1_scale=_q_scale * _k_scale * scale,
+ bmm2_scale=_v_scale,
+ o_sf_scale=-1,
+ batch_size=batch_size,
+ window_left=-1,
+ cum_seq_lens_q=qo_indptr,
+ cum_seq_lens_kv=kv_indptr,
+ enable_pdl=False,
+ is_causal=causal,
+ return_lse=True,
+ out=trtllm_out,
+ backend="cute-dsl",
+ )[0]
elif backend == "cudnn":
# cuDNN uses wrapper API
return backend_wrappers[backend].run(q, k, v)
새로운 flashinfer/attention/cute_dsl/fmha.py 파일은 CuTe DSL FMHA 커널의 구체적인 구현을 담고 있으며, flashinfer/attention/cute_dsl/__init__.py는 이 모듈을 패키지 외부로 노출하는 역할을 합니다. is_cute_dsl_available() 함수를 통해 CuTe DSL 사용 가능 여부를 확인할 수 있습니다.
왜 이게 좋은가?
성능 향상
이 PR의 가장 큰 장점은 명확한 성능 향상입니다. 제공된 벤치마크 결과는 CuTe DSL 기반 FMHA 커널이 기존 trtllm-native 백엔드 대비 상당한 속도 향상을 보여줍니다. 특히, Blackwell GPU (SM100)에서 FP8 (E4M3) 및 BF16 데이터 타입으로 테스트한 결과, 다양한 시퀀스 길이와 배치 크기 조합에서 최대 17.4%의 속도 향상을 달성했습니다.
예시:
- FP8 e4m3 (D=128), 1×8K×32K shape:
cute-dsl7.666ms vstrtllm-native8.998ms (+17.4% Speedup) - FP8 e4m3 (D=192), 4×512×82K shape:
cute-dsl6.397ms vstrtllm-native7.286ms (+13.9% Speedup)
이러한 성능 향상은 다음과 같은 요인들에 기인합니다:
- 사전 컴파일된 커널 (Pre-compiled Kernels): JIT 컴파일 대신 미리 컴파일된
.so(cubin) 파일을 로드함으로써, 런타임 컴파일 오버헤드가 제거됩니다. 이는 특히 모델 로딩 시간 단축 및 추론 시 지연 시간 감소에 기여합니다. - CuTe DSL 최적화: CuTe DSL은 NVIDIA GPU 아키텍처에 특화된 고성능 커널을 생성하기 위해 설계되었습니다. 이를 통해 메모리 접근 패턴, 연산 병렬성 등을 최적화하여 하드웨어 성능을 최대한 활용할 수 있습니다.
- Blackwell 아키텍처 특화: 이 커널들은 Blackwell 아키텍처(SM100/SM103/SM110)의 새로운 기능과 성능 특성을 활용하도록 설계되어, 해당 하드웨어에서 최고의 성능을 발휘합니다.
- 다양한 기능 지원: FP16, BF16, FP8 (E4M3) 데이터 타입 지원, 가변 길이 ragged prefill, 스킵-소프트맥스(skip-softmax) 희소성, 인과적/비인과적 마스킹 등 다양한 고급 기능을 지원하면서도 높은 성능을 유지합니다.
일반적 교훈
이 PR은 다음과 같은 중요한 기술적 교훈을 제공합니다:
- 하드웨어 특화 커널의 중요성: 최신 하드웨어 아키텍처의 기능을 최대한 활용하기 위해서는 해당 하드웨어에 최적화된 커널을 사용하는 것이 필수적입니다. CuTe DSL과 같은 DSL은 이러한 커널 개발을 효율적으로 만들어 줍니다.
- 사전 컴파일 전략의 이점: 런타임 컴파일 오버헤드를 제거하기 위해 미리 컴파일된 바이너리(cubin, .so)를 배포하는 전략은 성능 민감 애플리케이션에서 효과적일 수 있습니다. 이는 특히 라이브러리 배포 시 JIT 컴파일 환경 설정의 복잡성을 줄여줍니다.
- API 통합의 중요성: 새로운 고성능 커널을 기존 프레임워크(여기서는 TensorRT-LLM의
trtllm_ragged_attention_deepseek())와 매끄럽게 통합하는 것은 라이브러리의 채택률을 높이는 데 중요합니다. 이를 통해 사용자는 복잡한 내부 구현을 알 필요 없이 성능 향상을 누릴 수 있습니다. - 체크섬 검증의 필요성: 외부에서 다운로드하는 바이너리 파일의 무결성을 보장하기 위해 체크섬 검증은 필수적인 보안 및 안정성 조치입니다.
리뷰어 피드백 반영
리뷰 과정에서 몇 가지 중요한 피드백이 있었고, PR은 이를 반영하여 개선되었습니다.
- API 연동: 리뷰어
nvpohanh는 이 PR이 DSR1 MLA prefill을 위한 것이라면trtllm_ragged_attention_deepseek()API와 연동되어야 한다고 제안했습니다. 이는 프레임워크에서 MLA prefill을 위해 사용하는 API이기 때문입니다. PR 제출자는 이를 성공적으로 통합했습니다. - 성능 지표 공유:
nvpohanh는 성능 지표 공유를 요청했고, PR 제출자는 상세한 벤치마크 결과를 제공했습니다. 이 결과는 CuTe DSL 커널의 우수한 성능을 입증합니다. - 명칭 통일:
trtllm-native대신trtllm-gen으로 백엔드 명칭을 통일하자는 제안이 있었고, 이는 반영되었습니다. - Front-padding 제거:
cute-dsl백엔드의front-padding요구사항은 커널 개발자가 다음 PR에서 제거할 예정임을 밝혔습니다. 이는 현재 커널의 제약사항이지만, 향후 개선될 부분입니다. - Device 인자:
get_cute_dsl_fmha_kernel함수에device인자를 포함해야 한다는 피드백이 있었고, 이는 수정되었습니다. 이는 멀티 디바이스 환경에서의 문제를 방지하기 위함입니다.
결론
이번 PR은 FlashInfer 라이브러리에 NVIDIA CuTe DSL 기반의 FMHA 커널을 통합함으로써, 특히 Blackwell 아키텍처에서의 사전 생성(Prefill) 성능을 크게 향상시킨 중요한 업데이트입니다. 사전 컴파일된 커널 로딩, 하드웨어 최적화, 그리고 기존 API와의 매끄러운 통합은 LLM 추론 성능을 한 단계 끌어올리는 데 기여합니다. 제공된 성능 지표는 이러한 최적화의 효과를 명확히 보여주며, 이는 LLM 개발자들에게 더 빠르고 효율적인 추론 환경을 제공할 것입니다.
References
- NVIDIA CuTe DSL
- FlashInfer GitHub Repository
- TensorRT-LLM Documentation
- torch.compile (참고: 이 PR은
torch.compile을 직접 사용하지 않지만, CUDA 커널 컴파일 및 최적화와 관련된 일반적인 맥락에서 관련될 수 있습니다.) - FlashInfer
trtllm_ragged_attention_deepseekAPI (리뷰어 피드백에서 언급된 API)
참고 자료
- https://docs.nvidia.com/cuda/cutlass/
- https://github.com/flashinfer-ai/flashinfer
- https://docs.nvidia.com/deeplearning/tensorrt-llm/
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/prefill.py#L3696
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [vllm] vLLM에 고성능 JIT 양자화 커널 'Humming' 도입하기
- 현재글 : [flashinfer] FlashInfer, CuTe DSL 기반 FMHA 커널 통합으로 사전 생성(Prefill) 성능 극대화
- 다음글 [flashinfer] FlashInfer 오토튜너 최적화: 하이브리드 토큰 버킷 도입
댓글