[flashinfer] FlashInfer, CUDA 그래프 호환성을 높이고 성능을 최적화하다: TRT-LLM FMHA v2 통합 및 불필요한 H2D 제거
PR 링크: flashinfer-ai/flashinfer#2841 상태: Merged | 변경: +None / -None
들어가며
최근 LLM(거대 언어 모델)의 발전 속도는 눈부십니다. 이러한 모델의 효율적인 서빙을 위해서는 고성능 추론 엔진이 필수적이며, FlashInfer는 그 선두에 서 있는 라이브러리 중 하나입니다. 이번 PR은 FlashInfer가 NVIDIA의 TensorRT-LLM(TRT-LLM)에서 제공하는 최신 FMHA(Fused Multi-Head Attention) v2 커널을 통합하고, 동시에 CUDA 그래프(CUDA Graph) 환경에서의 호환성을 높여 성능을 개선하는 데 중점을 두고 있습니다.
특히, 이 PR은 다음과 같은 핵심적인 문제들을 해결합니다:
- TRT-LLM FMHA v2 통합: 최신 FMHA v2 커널을 FlashInfer 벤치마크 프레임워크에 통합하여 성능 비교 및 검증을 가능하게 합니다.
- CUDA 그래프 호환성 개선: 기존 코드에서 CUDA 그래프 래핑 시 발생할 수 있었던 비호환성 문제를 해결하여, 더 넓은 범위의 워크로드에서 FlashInfer를 활용할 수 있도록 합니다.
- 불필요한 Host-to-Device (H2D) 전송 제거: 메모리 대역폭을 절약하고 지연 시간을 줄이기 위해 불필요한 데이터 전송을 제거합니다.
이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 이러한 변경이 왜 성능 향상과 CUDA 그래프 호환성 증대에 기여하는지 기술적인 관점에서 설명하겠습니다.
코드 분석
이번 PR은 주로 benchmarks/routines/attention.py, benchmarks/routines/flashinfer_benchmark_utils.py, 그리고 TRT-LLM의 일부 커널 소스 코드(vendored)에 걸쳐 변경 사항을 포함하고 있습니다. 각 파일별 주요 변경 내용을 살펴보겠습니다.
1. benchmarks/routines/attention.py - 벤치마크 통합 및 CUDA 그래프 호환성 개선
이 파일은 FlashInfer의 다양한 어텐션 구현을 벤치마킹하는 핵심 로직을 담고 있습니다. 이번 PR에서는 trtllm-fmha-v2 백엔드를 새롭게 추가하고, 기존 코드의 CUDA 그래프 호환성을 높이기 위한 수정이 이루어졌습니다.
1.1. trtllm-fmha-v2 백엔드 추가
BatchPrefillWithPagedKVCacheWrapper 및 BatchPrefillWithRaggedKVCacheWrapper 함수 내에서 지원하는 백엔드 목록에 trtllm-fmha-v2가 추가되었습니다. 이는 새로운 FMHA v2 커널을 벤치마크에서 선택하고 실행할 수 있게 합니다.
--- a/benchmarks/routines/attention.py
+++ b/benchmarks/routines/attention.py
@@ -22,6 +22,8 @@
if not is_lib_missing:
raise
from flashinfer.fp4_quantization import nvfp4_quantize_paged_kv_cache
+from flashinfer.prefill import trtllm_fmha_v2_prefill
+from flashinfer.utils import is_sm12x_supported
from flashinfer.testing.utils import (
attention_tb_per_sec_with_actual_seq_lens,
attention_tflops_per_sec_with_actual_seq_lens,
@@ -111,6 +113,7 @@ def parse_attention_args(line, parser):
"cutlass",
"trtllm-gen",
"trtllm-native",
+ "trtllm-fmha-v2",
"trtllm-gen-native", # Deprecated, will be removed in future
"cute-dsl",
],
@@ -936,6 +939,9 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
remove_trtllm_native = True
if remove_trtllm_native:
backends.remove("trtllm-native")
+ if "trtllm-fmha-v2" in backends and is_nvfp4_kv:
+ print("[INFO] trtllm-fmha-v2 backend does not support NVFP4. Skipping.")
+ backends.remove("trtllm-fmha-v2")
if "cutlass" in backends:
print("[INFO] CUTLASS backend does not support prefill. Skipping.")
@@ -1072,7 +1078,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
.to(device)
)
- # Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr
+ # Page-based indptr for FlashInfer paged attention (cumulative page counts)
kv_indptr = (
torch.cat(
[
@@ -1086,6 +1092,17 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
.int()
.to(device)
)
+ # Token-based indptr for TRT-LLM backends (cumulative token counts)
+ kv_token_indptr = (
+ torch.cat(
+ [
+ torch.tensor([0], device=device),
+ torch.cumsum(actual_seq_lens_kv_device.flatten(), dim=0),
+ ]
+ )
+ .int()
+ .to(device)
+ )
kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32)
for i in range(len(kv_indptr) - 1):
start_idx = kv_indptr[i]
@@ -1158,6 +1175,14 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
v_quantized, _ = to_float8(v_data, kv_dtype)
kv_cache = torch.cat([k_quantized, v_quantized], dim=1)
+ # Ensure trtllm-fmha-v2 sees contiguous HND-physical paged KV cache.
+ # Skip if kv_cache is not a plain Tensor (e.g., NVFP4 packed tuple).
+ # backend filter further down also drops trtllm-fmha-v2 in that case.
+ if "trtllm-fmha-v2" in backends and isinstance(kv_cache, torch.Tensor):
+ _fmha_v2_kv_cache = kv_cache.contiguous()
+ else:
+ _fmha_v2_kv_cache = kv_cache
+
# Prepare wrappers (after FP8 conversion so we have correct dtypes)
backend_wrappers = {}
resolved_backends = {}
@@ -1305,6 +1330,26 @@ def run_backend_wrapper(
v_scale=v_scale_tensor,
o_data_type=o_data_type,
)[0]
+ elif backend == "trtllm-fmha-v2":
+ _q_scale = q_scale if q_scale is not None else 1.0
+ _k_scale = k_scale if k_scale is not None else 1.0
+ _fmha_v2_bmm2_scale = v_scale if v_scale is not None else 1.0
+ return trtllm_fmha_v2_prefill(
+ qkv=(q, _fmha_v2_kv_cache),
+ input_layout="Q_PAGED_KV_HND",
+ workspace_buffer=workspace_buffer,
+ seq_lens=actual_seq_lens_kv_device.flatten(),
+ max_q_len=s_qo,
+ max_kv_len=s_kv,
+ bmm1_scale=_q_scale * _k_scale * scale,
+ bmm2_scale=_fmha_v2_bmm2_scale,
+ batch_size=batch_size,
+ cum_seq_lens_q=qo_indptr,
+ cum_seq_lens_kv=kv_token_indptr,
+ block_tables=block_tables,
+ mask_mode="causal" if causal else "padding",
+ out_dtype=o_data_type,
+ )
else:
print(f"[ERROR] Backend {backend} not supported")
return None
@@ -1367,9 +1412,15 @@ def run_backend_wrapper(
tested_outputs = list(outputs.values())
# When cases where FA2 is not available, try to find an alternative reference
- # Priority: cudnn > cudnn-native > trtllm-gen > trtllm-native
+ # Priority: cudnn > cudnn-native > trtllm-gen > trtllm-native > trtllm-fmha-v2
if run_refcheck and not has_reference_output and len(tested_backends) > 1:
- reference_priority = ["cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"]
+ reference_priority = [
+ "cudnn",
+ "cudnn-native",
+ "trtllm-gen",
+ "trtllm-native",
+ "trtllm-fmha-v2",
+ ]
for candidate in reference_priority:
if candidate in tested_backends:
has_reference_output = True
또한, Paged KV Cache를 사용할 때 TRT-LLM 백엔드는 토큰 기반의 indptr를 기대하는 반면, FlashInfer의 래퍼는 페이지 기반 indptr를 사용합니다. 이 간극을 해소하기 위해 kv_token_indptr가 새롭게 계산되어 trtllm_fmha_v2_prefill 함수에 전달됩니다.
--- a/benchmarks/routines/attention.py
+++ b/benchmarks/routines/attention.py
@@ -1072,7 +1078,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
.to(device)
)
- # Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr
+ # Page-based indptr for FlashInfer paged attention (cumulative page counts)
kv_indptr = (
torch.cat(
[
@@ -1086,6 +1092,17 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
.int()
.to(device)
)
+ # Token-based indptr for TRT-LLM backends (cumulative token counts)
+ kv_token_indptr = (
+ torch.cat(
+ [
+ torch.tensor([0], device=device),
+ torch.cumsum(actual_seq_lens_kv_device.flatten(), dim=0),
+ ]
+ )
+ .int()
+ .to(device)
+ )
kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32)
for i in range(len(kv_indptr) - 1):
start_idx = kv_indptr[i]
1.2. CUDA 그래프 호환성 개선을 위한 KV Cache 처리
CUDA 그래프는 커널 실행 시점에 동적으로 할당되는 메모리나 Host-to-Device (H2D) 복사를 캡처할 수 없습니다. 이 PR은 이러한 제약을 해결하기 위해 Paged KV Cache 처리를 개선했습니다.
기존에는 kv_cache.contiguous()를 항상 호출하여 trtllm-fmha-v2 백엔드가 기대하는 물리적 레이아웃과 일치시키려 했으나, 이는 NVFP4와 같이 kv_cache가 torch.Tensor가 아닌 경우(isinstance(kv_cache, torch.Tensor) 조건 추가) 문제를 일으킬 수 있었습니다. 또한, contiguous() 호출 자체가 CUDA 그래프 캡처 시점에 문제를 일으킬 수 있습니다. 이를 해결하기 위해 kv_cache가 torch.Tensor일 경우에만 .contiguous()를 적용하도록 수정했습니다.
--- a/benchmarks/routines/attention.py
+++ b/benchmarks/routines/attention.py
@@ -1158,6 +1175,14 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
v_quantized, _ = to_float8(v_data, kv_dtype)
kv_cache = torch.cat([k_quantized, v_quantized], dim=1)
+ # Ensure trtllm-fmha-v2 sees contiguous HND-physical paged KV cache.
+ # Skip if kv_cache is not a plain Tensor (e.g., NVFP4 packed tuple).
+ # backend filter further down also drops trtllm-fmha-v2 in that case.
+ if "trtllm-fmha-v2" in backends and isinstance(kv_cache, torch.Tensor):
+ _fmha_v2_kv_cache = kv_cache.contiguous()
+ else:
+ _fmha_v2_kv_cache = kv_cache
+
# Prepare wrappers (after FP8 conversion so we have correct dtypes)
backend_wrappers = {}
resolved_backends = {}
2. benchmarks/routines/flashinfer_benchmark_utils.py - Compute Capability별 백엔드 지원 명시
이 파일은 특정 CUDA Compute Capability(SM 버전)에서 어떤 백엔드가 지원되는지를 정의합니다. 이번 PR에서는 SM90 (Ampere) 및 SM120 (Hopper)에서 trtllm-fmha-v2 백엔드를 지원하도록 업데이트되었습니다.
--- a/benchmarks/routines/flashinfer_benchmark_utils.py
+++ b/benchmarks/routines/flashinfer_benchmark_utils.py
@@ -320,25 +320,27 @@ def dtype_str_to_torch_dtype(dtype_str):
},
"BatchPrefillWithPagedKVCacheWrapper": {
# NOTE: trtllm-native calls trtllm_batch_context_with_kv_cache
+ # NOTE: trtllm-fmha-v2 calls trtllm_fmha_v2_prefill
# NOTE: cudnn-native calls cudnn_batch_prefill_with_kv_cache
"7.5": [],
"8.0": ["fa2", "auto", "cudnn", "cudnn-native"],
"8.6": ["fa2", "auto", "cudnn", "cudnn-native"],
"8.9": ["fa2", "auto", "cudnn", "cudnn-native"],
- "9.0": ["fa2", "fa3", "auto", "cudnn", "cudnn-native"],
+ "9.0": ["fa2", "fa3", "auto", "cudnn", "cudnn-native", "trtllm-fmha-v2"],
"10.0": ["fa2", "auto", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"],
"10.3": ["fa2", "auto", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"],
- "12.0": ["fa2", "auto", "cudnn", "cudnn-native"],
+ "12.0": ["fa2", "auto", "cudnn", "cudnn-native", "trtllm-fmha-v2"],
"12.1": ["fa2", "auto", "cudnn", "cudnn-native"],
},
"BatchPrefillWithRaggedKVCacheWrapper": {
# NOTE: trtllm-native calls trtllm_ragged_attention_deepseek
+ # NOTE: trtllm-fmha-v2 calls trtllm_fmha_v2_prefill
# NOTE: cudnn-native calls cudnn_batch_prefill_with_kv_cache
"7.5": [],
"8.0": ["fa2", "cudnn", "cudnn-native"],
"8.6": ["fa2", "cudnn", "cudnn-native"],
"8.9": ["fa2", "cudnn", "cudnn-native"],
- "9.0": ["fa2", "fa3", "cudnn", "cudnn-native"],
+ "9.0": ["fa2", "fa3", "cudnn", "cudnn-native", "trtllm-fmha-v2"],
"10.0": [
"fa2",
"cudnn",
@@ -355,7 +357,7 @@ def dtype_str_to_torch_dtype(dtype_str):
"cute-dsl",
"trtllm-native",
],
- "12.0": ["fa2", "cudnn", "cudnn-native"],
+ "12.0": ["fa2", "cudnn", "cudnn-native", "trtllm-fmha-v2"],
"12.1": ["fa2", "cudnn", "cudnn-native"],
},
"BatchMLAPagedAttentionWrapper": {
3. TRT-LLM 커널 소스 코드 수정 (Vendored Code)
이 PR은 TRT-LLM의 일부 커널 소스 코드를 직접 수정(vendored)하여 FlashInfer 내에서 사용하고 있습니다. 이는 TRT-LLM의 최신 기능을 활용하면서도 FlashInfer의 요구사항에 맞게 최적화하기 위함입니다. 주요 수정 내용은 다음과 같습니다.
3.1. scale_bmm2의 Host-to-Device (H2D) 전송 제거
기존에는 scale_bmm2 값이 커널 실행 시점에 GPU 메모리에 복사되어야 했습니다. 이 과정은 호스트에서 디바이스로의 비동기 복사(cudaMemcpyAsync)를 필요로 했으며, 이는 CUDA 그래프 캡처 시점에 문제를 일으켰습니다. 왜냐하면 복사될 호스트 포인터가 래퍼 함수가 반환된 후 유효하지 않게 될 수 있기 때문입니다.
이번 PR에서는 SM90 Warp-Specialized (WS) 에필로그가 params.scale_bmm2 값을 커널 인자 버퍼에서 직접 읽도록 수정되었습니다. params.scale_bmm2는 uint32 값으로 커널 인자 버퍼에 값 자체로 캡처되므로, 커널 재실행 시 새로운 scale 값을 자연스럽게 반영할 수 있습니다. 이 변경은 SM90 WS 에필로그에만 국한되며, 다른 아키텍처(SM80 등)는 영향을 받지 않습니다.
--- a/csrc/fmha_v2/fmha/warpspec/epilogue.h
+++ b/csrc/fmha_v2/fmha/warpspec/epilogue.h
@@ -1133,7 +1133,7 @@
const auto scale_bmm2_d = params.scale_bmm2_d;
const auto scale_bmm2 = params.scale_bmm2;
- const float scale = (scale_bmm2_d == nullptr) ? scale_bmm2 : *scale_bmm2_d;
+ const float scale = (scale_bmm2_d == nullptr) ? scale_bmm2 : scale_bmm2;
// If the output is FP8, we need to scale the output by 1/scale.
// This is because the epilogue is responsible for scaling the output
3.2. softmax_stats 할당 조건화
이전 코드에서는 softmax_stats에 대한 워크스페이스 할당이 항상 이루어졌습니다. 하지만 softmax_stats_ptr가 nullptr인 경우(즉, 호출자가 출력 텐서를 제공하지 않는 경우)에는 할당이 불필요합니다. 이번 수정으로 불필요한 할당이 제거되어, CUDA 그래프 환경에서 더 깔끔하게 동작하게 됩니다.
3.3. expanded_block_tables 제거
Paged KV Cache를 사용할 때, trtllm_fmha_v2_prefill은 [B, M] 형태의 논리적 페이지 인덱스를 [B, 2, M] 형태의 K/V 풀 오프셋으로 확장하는 작업을 매 호출마다 수행했습니다. 이 과정은 torch.stack([bt*2, bt*2+1], dim=1).contiguous()와 같은 연산을 포함하며, 매 호출마다 2~4개의 임시 GPU 텐서를 할당하고 메모리 압박 시 캐싱 할당기 동기화를 유발할 수 있었습니다.
이 PR에서는 Kv_block_array에 mUsesSharedPagedKvIdx 플래그를 추가하고, Ampere 및 Hopper 아키텍처의 커널이 블록 오프셋 로드 중에 인터리브된 풀 오프셋(page_idx * 2 + kv_type)을 온디맨드(on-the-fly)로 계산하도록 변경했습니다. 이를 통해 [B, M] 형태의 block_tables가 커널에 직접 전달되며, 불필요한 텐서 할당과 복사가 제거됩니다.
--- a/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h
+++ b/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h
@@ -863,11 +863,21 @@ struct Gmem_tile_paged_kv {
paged_kv_log2_block_size_(params.paged_kv_cache.mTokensPerBlockLog2),
paged_kv_block_pool_ptr_(reinterpret_cast<char*>(params.paged_kv_cache.mPoolPtr)),
paged_kv_global_block_offsets_(params.paged_kv_cache.mBlockOffsets),
- params_kv_block_size_in_bytes_(params.paged_kv_cache.mBytesPerBlock) {
- // Handle Paged KV with shape [S, Dh],
+ params_kv_block_size_in_bytes_(params.paged_kv_cache.mBytesPerBlock),
+ mUsesSharedPagedKvIdx(params.paged_kv_cache.mUsesSharedPagedKvIdx) {
+ // Handle Paged KV with shape [S, Dh],
// If mUsesSharedPagedKvIdx is true, then block_tables is [B, M]
// and we need to compute the interleaved pool offset on-the-fly.
+ // Otherwise, block_tables is [B, 2, M] and we can use it directly.
+ // The default value for mUsesSharedPagedKvIdx is false.
+ if (mUsesSharedPagedKvIdx) {
+ // Check if the input block_tables is [B, M]
+ if (block_tables.ndim() != 2) {
+ throw std::runtime_error("block_tables must be [B, M] when mUsesSharedPagedKvIdx is true");
+ }
+ }
+
// If mUsesSharedPagedKvIdx is true, then block_tables is [B, M]
// and we need to compute the interleaved pool offset on-the-fly.
// Otherwise, block_tables is [B, 2, M] and we can use it directly.
3.4. scale_softmax = 1.0 강제 적용
scale_softmax 값이 0.0일 때, C++ 코드에서 자동으로 스케일을 감지하는 로직이 FP16/INT8/E4M3 데이터 타입에는 잘 동작하지만, BF16 데이터 타입에서는 소프트맥스 출력을 0으로 만드는 버그가 있었습니다. 이를 방지하기 위해 scale_softmax 값을 항상 1.0으로 명시적으로 전달하도록 수정되었습니다.
왜 이게 좋은가?
이번 PR의 변경 사항들은 여러 측면에서 성능 향상과 유용성 증대에 기여합니다.
- CUDA 그래프 호환성 향상: 가장 중요한 개선점 중 하나입니다. CUDA 그래프는 동적 할당 및 H2D 복사를 캡처하지 못하므로, 이러한 비호환성 요소를 제거하는 것은 프로덕션 환경에서 FlashInfer를 더욱 견고하게 사용할 수 있게 합니다. 특히
scale_bmm2의 H2D 전송 제거와expanded_block_tables로직의 온디맨드 계산은 CUDA 그래프 래핑 시 발생하는 주요 병목 현상을 해결합니다. - 메모리 대역폭 및 지연 시간 감소:
expanded_block_tables를 제거함으로써 매 호출마다 발생하는 불필요한 텐서 할당과 GPU 메모리 복사가 사라집니다. 이는 특히 긴 시퀀스나 많은 배치 사이즈에서 상당한 성능 향상을 가져올 수 있습니다. 또한,scale_bmm2의 H2D 전송 제거는 GPU-CPU 간 통신 오버헤드를 줄여줍니다. - 최신 하드웨어 및 커널 활용:
trtllm-fmha-v2백엔드를 통합함으로써 사용자는 NVIDIA의 최신 FMHA 구현을 FlashInfer 프레임워크 내에서 직접 활용하고 성능을 비교할 수 있게 되었습니다. 이는 최신 GPU 아키텍처(예: Hopper)에서 최적의 성능을 달성하는 데 도움이 됩니다. - 벤치마크 및 테스트 강화: 새로운 백엔드 통합은 벤치마크 스위트를 더욱 포괄적으로 만들어, 다양한 구현 간의 성능을 정확하게 비교할 수 있게 합니다. 또한,
trtllm-fmha-v2의 입력 형식 요구사항(예:contiguous()KV cache, 토큰 기반indptr)을 처리함으로써 테스트 커버리지를 넓혔습니다.
성능 수치: PR 설명에는 구체적인 성능 수치가 명시적으로 포함되어 있지 않지만, 리뷰어의 댓글에 첨부된 nsys 트레이스 이미지는 H2D 복사 및 작은 커널들이 제거되었음을 시각적으로 보여줍니다. 이는 성능 개선의 간접적인 증거가 됩니다.
일반적인 교훈: 이 PR은 다음과 같은 일반적인 최적화 교훈을 제공합니다:
- CUDA 그래프 제약 조건 이해: CUDA 그래프를 지원해야 하는 라이브러리에서는 동적 할당, H2D 복사, 포인터 유효성 문제 등을 신중하게 고려해야 합니다.
- 불필요한 메모리 할당 및 복사 제거: 벤치마크 및 프로파일링 도구를 사용하여 런타임에 발생하는 불필요한 메모리 작업을 식별하고 제거하는 것은 성능 향상의 핵심입니다.
- 커널 입력 형식 표준화: 다양한 백엔드를 지원할 때, 각 백엔드의 입력 형식 요구사항을 정확히 파악하고, 필요한 경우 데이터 변환 로직을 효율적으로 구현해야 합니다.
- 벤더 라이브러리 통합 시 주의: 외부 라이브러리(여기서는 TRT-LLM)를 직접 수정하여 사용할 때는, 업스트림 변경 사항과의 동기화(rebase) 전략을 명확히 하고, 변경 사항을 명확히 문서화해야 합니다 (예:
DIVERGENCE주석).
리뷰어 피드백 반영
리뷰어들의 피드백은 이 PR의 완성도를 높이는 데 중요한 역할을 했습니다.
scale_bmm2H2D 제거 확인: Akhil G.은scale_bmm2의 H2D 전송 제거 및 작은 커널 제거를 확인하기 위해nsys트레이스를 요청했습니다. Jimmy Z.가 제공한 트레이스 이미지는 이러한 변경이 성공적으로 이루어졌음을 보여줍니다.- Ragged 테스트 및 FP4 호환성: Salty Minty는 Ragged 테스트에서 SM120의
SEPARATE_Q_K_V폴백 및 FP4 KV Cache 사용 시.contiguous()호출이 문제를 일으킬 수 있음을 지적했습니다. 이에 대한 수정으로isinstance(kv_cache, torch.Tensor)체크가 추가되었고, FP4 사용 시trtllm-fmha-v2가 자동으로 필터링되도록 개선되었습니다. is_paged_hnd제거: C++ 코드에서 사용되지 않는 인자가 제거되었습니다.
이러한 피드백은 코드의 정확성, 견고성, 그리고 CUDA 그래프 호환성을 더욱 강화하는 데 기여했습니다.
결론
이번 PR은 FlashInfer가 TRT-LLM의 최신 FMHA v2 커널을 성공적으로 통합하고, CUDA 그래프 환경에서의 호환성을 크게 향상시킨 중요한 업데이트입니다. 불필요한 H2D 전송 제거와 동적 메모리 할당 로직 최적화를 통해 성능 또한 개선되었습니다. 이러한 변경은 LLM 추론의 효율성을 높이고, 더 넓은 범위의 하드웨어 및 소프트웨어 환경에서 FlashInfer를 유용하게 사용할 수 있도록 만들 것입니다.
참고 자료
- https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fmhaKernels.h
- https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fmha_v2/fmha/warpspec/epilogue.h
- https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fmha_v2/fmha/gmem_tile_qkv_packed.h
- https://pytorch.org/docs/stable/generated/torch.cumsum.html
- https://pytorch.org/docs/stable/generated/torch.stack.html
- https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang: MiniMax-M2.5 MoE 모델을 위한 FP8 FlashInfer TRT-LLM 라우팅 최적화
- [flashinfer] FlashInfer: Wide Vector 최적화와 1900줄의 코드 삭제로 달성한 성능 개선
- [flashinfer] FlashInfer, CuTe DSL 기반 FMHA 커널 통합으로 사전 생성(Prefill) 성능 극대화
- [flashinfer] FlashInfer의 DiT 최적화: SageAttention과 Int8/FP8 혼합 정밀도 커널 도입 분석
- [vllm] vLLM, DCP A2A 어텐션 백엔드 최적화: 단일 All-to-All 콜렉티브로 성능 향상
PR Analysis 의 다른글
- 이전글 [flashinfer] FlashInfer: Wide Vector 최적화와 1900줄의 코드 삭제로 달성한 성능 개선
- 현재글 : [flashinfer] FlashInfer, CUDA 그래프 호환성을 높이고 성능을 최적화하다: TRT-LLM FMHA v2 통합 및 불필요한 H2D 제거
- 다음글 [vllm] vLLM, Gemma 4 모델에 양자화된 Speculative Decoding 적용: 성능 향상의 비밀
댓글