본문으로 건너뛰기

[vllm] vLLM, GDN Prefill 커널을 CuteDSL로 최적화하여 성능 향상

PR 링크: vllm-project/vllm#43273 상태: Merged | 변경: +3265 / -160

들어가며

최근 대규모 언어 모델(LLM)의 발전 속도는 눈부십니다. 특히 Mamba와 같은 새로운 아키텍처는 기존의 트랜스포머 모델과는 다른 방식으로 시퀀스 데이터를 처리하며 주목받고 있습니다. Mamba의 핵심 요소 중 하나인 Gated-Delta Net (GDN)은 선형 어텐션의 한 형태로, 긴 시퀀스 처리에서 효율성을 높이는 데 기여합니다. vLLM은 이러한 최신 기술을 빠르게 도입하고 최적화하는 데 앞장서고 있으며, 이번 PR은 vLLM의 GDN Prefill 연산에 새로운 CuteDSL 기반 커널을 도입하여 성능을 크게 향상시키는 것을 목표로 합니다.

이 PR은 기존의 GDN Prefill 커널의 병목 현상을 해결하고, 특히 Tensor Core 활용도를 높여 SM100 아키텍처에서 더 나은 성능을 달성하고자 합니다. 기존 FlashInfer 백엔드 대신, CuteDSL을 활용한 새로운 커널은 JIT(Just-In-Time) 컴파일의 이점을 살려 다양한 헤드 구성에 대한 적응성과 개발 용이성을 제공합니다. 이 글에서는 이 PR의 코드 변경 사항을 자세히 살펴보고, 어떤 점이 개선되었으며 왜 이것이 vLLM의 성능 향상에 중요한지 분석합니다.

코드 분석

이번 PR의 핵심은 vllm.model_executor.layers.mamba.ops.gdn_chunk_cutedsl 모듈에 새로운 GDN Prefill 커널을 추가하고, 이를 활용하도록 관련 코드를 수정하는 것입니다. 변경 사항은 크게 세 가지 커널 디자인과 관련 유틸리티 함수로 나눌 수 있습니다.

1. 커널 디자인: kkt_inv_uw_kernel, h_kernel, o_kernel

이 PR에서 도입된 커널은 기존 FLA(Flash Linear Attention)의 디자인을 따르면서도 최적화를 더했습니다. 주요 구성 요소는 다음과 같습니다:

  • kkt_inv_uw_kernel: FLA의 chunk_scaled_dot_kkt_fwd(), solve_tril(), recompute_w_u_fwd()를 융합한 커널입니다. 이 커널은 GDN의 핵심 병목 중 하나인 행렬 역행렬 계산을 Tensor Core를 활용하는 Newton-Schulz iteration 방식으로 개선했습니다.
  • h_kernel: FLA의 chunk_gated_delta_rule_fwd_h()에 해당합니다. H 커널은 V를 위한 패킹과 H를 위한 패킹 과정에서 발생하는 MMA(Matrix Multiply-Accumulate) 유휴 시간을 줄이는 것이 중요합니다.
  • o_kernel: FLA의 chunk_fwd_o()에 해당합니다. 이 커널은 상대적으로 Tensor Core 활용도가 높은 편입니다.

이러한 커널 설계는 Tensor Core 파이프라인의 버블(bubble)을 최소화하여 성능을 향상시키는 것을 목표로 합니다. 특히, kkt_inv_uw_kernel에서 행렬 역행렬 계산을 CUDA 코어가 아닌 Tensor Core로 처리하는 것은 큰 개선점입니다.

2. CuteDSL 및 관련 유틸리티

새로운 커널은 CuteDSL을 사용하여 구현되었습니다. CuteDSL은 CUTLASS 라이브러리의 DSL(Domain Specific Language)로, CUDA 커널 개발을 더 쉽고 효율적으로 만들어줍니다. 이번 PR에서는 다음과 같은 CuteDSL 관련 파일들이 추가되거나 수정되었습니다:

  • vllm/cute_utils/__init__.py: CuteDSL의 기본적인 연산(recast_val, simple_tma_copy, mma_bf16 등)을 래핑하는 유틸리티 함수들을 포함합니다. 특히 mma_bf16 함수는 BF16 데이터 타입에 대한 MMA 연산을 효율적으로 처리하도록 설계되었습니다.
  • vllm/cute_utils/_tcgen05.py: Tensor Core 관련 저수준 연산을 위한 유틸리티를 제공합니다. (이 파일은 diff에서 일부만 보이지만, 커널 구현에 필요한 저수준 기능을 제공할 것으로 예상됩니다.)
  • vllm/model_executor/layers/mamba/ops/gdn_chunk_cutedsl.py: 실제 GDN Prefill 커널 구현이 포함된 파일입니다. chunk_gated_delta_rule_cutedsl 함수가 핵심이며, prepare_metadata_cutedsl 함수는 커널 실행에 필요한 메타데이터를 준비합니다.
  • tests/kernels/mamba/test_gdn_prefill_cutedsl.py: 새로 추가된 커널의 정확성을 검증하기 위한 테스트 코드입니다. 다양한 시퀀스 길이와 배치 크기에 대해 기존 FLA 커널과의 정확도를 비교합니다.

3. 코드 변경 예시 (diff)

새로운 커널의 구현은 vllm/model_executor/layers/mamba/ops/gdn_chunk_cutedsl.py 파일에 주로 포함되어 있습니다. 테스트 파일 tests/kernels/mamba/test_gdn_prefill_cutedsl.py는 이 커널의 정확성을 검증합니다. 예를 들어, 테스트 파일의 일부는 다음과 같습니다:

--- /dev/null
+++ b/tests/kernels/mamba/test_gdn_prefill_cutedsl.py
@@ -0,0 +1,199 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import math
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from vllm.platforms import current_platform
+
+if not (
+    current_platform.is_cuda() and current_platform.is_device_capability_family(100)
+):
+    pytest.skip(
+        reason="GDN CuteDSL prefill requires CUDA SM10x.",
+        allow_module_level=True,
+    )
+
+from vllm.model_executor.layers.fla.ops import (  # noqa: E402
+    chunk_gated_delta_rule,
+) 
+from vllm.model_executor.layers.fla.ops.index import (  # noqa: E402
+    prepare_chunk_indices,
+    prepare_chunk_offsets,
+) 
+from vllm.model_executor.layers.mamba.ops.gdn_chunk_cutedsl import (  # noqa: E402
+    chunk_gated_delta_rule_cutedsl,
+    prepare_metadata_cutedsl,
+) 
+
+
+@pytest.mark.parametrize("num_seqs", [1, 5, 257])
+@pytest.mark.parametrize("state_dtype", [torch.bfloat16, torch.float32])
+def test_gdn_chunk_cutedsl_correctness(num_seqs: int, state_dtype: torch.dtype):
+    seq_lens = torch.randint(
+        1,
+        130,
+        (num_seqs,),
+        dtype=torch.int32,
+    )
+    cu_seqlens = torch.zeros(num_seqs + 1, device="cuda", dtype=torch.int32)
+    cu_seqlens[1:] = seq_lens.to(device="cuda").cumsum(0)
+    total_tokens = int(cu_seqlens[-1].item())
+
+    num_k_heads = 4
+    num_v_heads = 8
+    head_k_dim = 128
+    head_v_dim = 128
+    dtype = torch.bfloat16
+
+    q = torch.randn(
+        1,
+        total_tokens,
+        num_k_heads,
+        head_k_dim,
+        device="cuda",
+        dtype=dtype,
+    )
+    k = torch.randn_like(q)
+    v = torch.randn(
+        1,
+        total_tokens,
+        num_v_heads,
+        head_v_dim,
+        device="cuda",
+        dtype=dtype,
+    )
+    q = F.normalize(q.float(), p=2, dim=-1).to(dtype)
+    k = F.normalize(k.float(), p=2, dim=-1).to(dtype)
+    a = torch.randn(
+        1,
+        total_tokens,
+        num_v_heads,
+        device="cuda",
+        dtype=dtype,
+    )
+    b = torch.randn(
+        1,
+        total_tokens,
+        num_v_heads,
+        device="cuda",
+        dtype=dtype,
+    )
+    # Match upstream FLA GatedDeltaNet synthetic initialization:
+    # https://github.com/fla-org/flash-linear-attention/blob/main/fla/layers/gated_deltanet.py
+    A = torch.empty(num_v_heads, device="cuda", dtype=torch.float32).uniform_(0, 16)
+    A_log = torch.log(A)
+    dt = torch.exp(
+        torch.rand(num_v_heads, device="cuda", dtype=torch.float32)
+        * (math.log(0.1) - math.log(0.001))
+        + math.log(0.001)
+    )
+    dt = torch.clamp(dt, min=1e-4)
+    dt_bias = dt + torch.log(-torch.expm1(-dt))
+    g = -A_log.exp().view(1, 1, num_v_heads) * F.softplus(
+        a.float() + dt_bias.view(1, 1, num_v_heads)
+    )
+    beta = torch.sigmoid(b.float())
+    initial_state = (
+        torch.randn(
+            num_seqs,
+            num_v_heads,
+            head_v_dim,
+            head_k_dim,
+            device="cuda",
+            dtype=state_dtype,
+        )
+        * 0.05
+    )
+
+    # check metadata kernel
+    chunk_indices, chunk_offsets = prepare_metadata_cutedsl(cu_seqlens, total_tokens)
+    torch.accelerator.synchronize()
+
+    expected_indices = prepare_chunk_indices(cu_seqlens, 64)
+    expected_offsets = prepare_chunk_offsets(cu_seqlens, 64)
+    total_chunks = int(expected_offsets[-1].item())
+
+    torch.testing.assert_close(chunk_offsets, expected_offsets.to(torch.int32))
+    torch.testing.assert_close(
+        chunk_indices[:total_chunks],
+        expected_indices,
+    )
+
+    ref_o, ref_state = chunk_gated_delta_rule(
+        q=q,
+        k=k,
+        v=v,
+        g=g,
+        beta=beta,
+        initial_state=initial_state,
+        output_final_state=True,
+        cu_seqlens=cu_seqlens,
+        use_qk_l2norm_in_kernel=False,
+    )
+    actual_core_attn_out = torch.empty(
+        total_tokens,
+        num_v_heads,
+        head_v_dim,
+        device="cuda",
+        dtype=dtype,
+    )
+    actual_o, actual_state = chunk_gated_delta_rule_cutedsl(
+        q=q,
+        k=k,
+        v=v,
+        g=g,
+        beta=beta,
+        initial_state=initial_state,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
+        core_attn_out=actual_core_attn_out,
+    )
+    torch.accelerator.synchronize()
+
+    # check main kernel
+    o_error = (actual_o.float() - ref_o.float()).abs()
+    state_error = (
+        actual_state.float() - ref_state.to(actual_state.dtype).float()
+    ).abs()
+    assert o_error.max().item() < 2e-3
+    assert o_error.mean().item() < 6e-5
+    assert state_error.max().item() < 2e-2
+    assert state_error.mean().item() < 6e-4
+    core_attn_out_error = (
+        actual_core_attn_out.float() - actual_o.squeeze(0).float()
+    ).abs()
+    assert core_attn_out_error.max().item() == 0
+
+    # check main kernel when core_attn_out is not passed
+    no_buffer_o, no_buffer_state = chunk_gated_delta_rule_cutedsl(
+        q=q,
+        k=k,
+        v=v,
+        g=g,
+        beta=beta,
+        initial_state=initial_state,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
+    )
+    torch.accelerator.synchronize()
+
+    no_buffer_o_error = (no_buffer_o.float() - ref_o.float()).abs()
+    no_buffer_state_error = (
+        no_buffer_state.float() - ref_state.to(no_buffer_state.dtype).float()
+    ).abs()
+    buffer_o_error = (no_buffer_o.float() - actual_o.float()).abs()
+    buffer_state_error = (
+        no_buffer_state.float() - actual_state.to(no_buffer_state.dtype).float()
+    ).abs()
+    assert no_buffer_o_error.max().item() < 2e-3
+    assert no_buffer_o_error.mean().item() < 6e-5
+    assert no_buffer_state_error.max().item() < 2e-2
+    assert no_buffer_state_error.mean().item() < 6e-4
+    assert buffer_o_error.max().item() == 0
+    assert buffer_state_error.max().item() == 0

이 테스트 코드는 chunk_gated_delta_rule_cutedsl 함수가 기존의 chunk_gated_delta_rule 함수와 동일한 결과를 내는지, 그리고 core_attn_out 버퍼를 사용했을 때와 사용하지 않았을 때의 결과가 일치하는지를 검증합니다. 이는 새로운 커널의 정확성을 보장하는 중요한 단계입니다.

왜 이게 좋은가?

이번 PR은 여러 측면에서 vLLM의 GDN Prefill 성능을 향상시킵니다.

1. 성능 향상

마이크로벤치마크 결과에 따르면, 새로운 CuteDSL 기반 커널은 특히 Tensor Core 활용도가 낮은 시나리오, 예를 들어 헤드 수가 적은(TP) 설정이나 짧은 시퀀스 길이에서 기존 FlashInfer 백엔드보다 훨씬 뛰어난 성능을 보입니다. 예를 들어, Qwen3.5-397B-A17B TP4 모델과 1x8192 시퀀스 길이에서 기존 FLA 대비 3.46배의 속도 향상을 달성했습니다. BF16 상태를 사용할 경우에도 유사한 성능 향상이 관찰됩니다.

| Model                 | Seq      |   H |   Hv | FLA      | FlashInfer       | CuteDSL (this PR) |    FI MAE |   CuteDSL MAE |
|:----------------------|:---------|----:|-----:|:---------|:-----------------|:-----------------|----------:|--------------:|
| Qwen3.5-397B-A17B TP4 | 1x8192   |   4 |   16 | 0.839 ms | 0.440 ms (1.91x) | **0.243 ms (3.46x)** | 4.2e-05   |     3.32e-05  |

또한, E2E(End-to-End) 속도 측정에서도 긍정적인 결과가 나타났습니다. Qwen3.6-27B TP1 모델에서 BF16 상태를 사용할 경우, 기존 FLA 대비 7.3% 더 높은 처리량(tok/s)을 보였으며, TTFT(Time To First Token)도 개선되었습니다.

Backend | State dtype | TPGS | TTFT | TPOT
---|---|---|---|---
Trition/FLA | BF16 | 13977 tok/s | 67.79 s | 102.55 ms
FlashInfer | BF16 | 13829 tok/s (-1.1% from FLA) | 80.96 s | 99.87 ms
CuteDSL (this PR) | BF16 | 14993 tok/s (+7.3% from FLA) | 75.40 s | 92.17 ms

2. Tensor Core 활용도 개선

기존 GDN 구현의 주요 병목 중 하나는 행렬 역행렬 계산이었습니다. 이 PR에서는 CUDA 코어를 사용하는 대신, Tensor Core를 활용하는 Newton-Schulz iteration 방식을 도입했습니다. 이는 Tensor Core의 활용도를 크게 높여 연산 속도를 개선하는 핵심 요인입니다.

3. BF16 네이티브 지원

이 커널은 BF16 상태에 대한 네이티브 지원을 제공합니다. 즉, BF16 데이터를 FP32로 변환하는 추가적인 오버헤드 없이 직접 연산을 수행할 수 있어, 메모리 대역폭과 연산 효율성을 높입니다.

4. 개발 용이성 및 JIT 이점

CuteDSL을 사용함으로써 개발자는 더 높은 수준의 추상화에서 커널을 작성할 수 있으며, JIT 컴파일을 통해 다양한 하드웨어 및 모델 구성(예: 헤드 구성)에 대한 최적화를 자동으로 적용받을 수 있습니다. 이는 향후 vLLM의 다양한 모델 지원 및 성능 최적화에 유리하게 작용할 것입니다.

5. 일반적인 교훈

  • 커널 융합 및 재설계: 여러 연산을 하나의 커널로 융합하고, 병목 구간을 식별하여 Tensor Core 친화적인 알고리즘으로 재설계하는 것은 성능 향상의 핵심입니다.
  • DSL 활용: CuteDSL과 같은 DSL은 복잡한 GPU 커널 개발을 단순화하고, 이식성과 유지보수성을 높이는 데 기여합니다.
  • 데이터 타입 최적화: BF16과 같은 저정밀도 데이터 타입을 네이티브로 지원하는 것은 LLM 추론 성능에 큰 영향을 미칩니다.
  • 정확도 검증: 새로운 커널 도입 시, 기존 구현과의 정확도 차이를 엄격하게 검증하는 것이 필수적입니다. 이 PR에서는 FLA를 기준으로 Mean Absolute Error (MAE) 및 Max Absolute Error를 비교하여 정확도를 보장했습니다.

리뷰 피드백 반영

리뷰 과정에서 몇 가지 중요한 피드백이 있었습니다:

  • gdn_prefill_bench.py 스크립트 활용: 리뷰어(arpera)는 제안된 벤치마크 스크립트를 사용하여 성능을 평가하도록 요청했습니다. PR 작성자는 이미 유사한 마이크로벤치마크를 수행했음을 밝혔고, 결과가 긍정적이어서 추가적인 스크립트 실행은 불필요하다고 판단했습니다.
  • E2E 성능 측정: 다양한 모델 구성(Qwen3.5-397B-A17B-NVFP4, DEP8 토폴로지 등)에 대한 E2E 성능 측정이 요청되었습니다. PR 작성자는 GB200에서 DEP4 토폴로지를 사용한 결과를 공유했으나, 결과가 노이즈 수준이어서 명확한 우위를 보이지는 못했습니다. 이는 통신 오버헤드(NCCL AllReduce)가 GDN 커널 자체의 성능 개선을 상쇄할 수 있음을 시사합니다.
  • 코드 중복 제거: vllm/v1/attention/backends/gdn_attn.py 파일에서 코드 중복이 지적되어 유틸리티 함수로 분리하는 작업이 완료되었습니다.
  • 정확도 검증: vllm/model_executor/layers/mamba/gdn/qwen_gdn_linear_attn.py 파일에서 정확도 검증 로직이 중복되어 한 번만 수행하도록 수정되었습니다.

이러한 피드백은 PR의 완성도를 높이는 데 기여했습니다.

결론

이번 vLLM의 PR은 CuteDSL을 활용하여 GDN Prefill 연산을 위한 새로운 커널을 성공적으로 도입했습니다. 이 커널은 Tensor Core 활용도를 높이고, BF16 상태를 네이티브로 지원하며, 개발 용이성을 개선함으로써 vLLM의 추론 성능을 크게 향상시킵니다. 특히 헤드 수가 적거나 시퀀스 길이가 긴 워크로드에서 그 효과가 두드러집니다. 향후 이 커널은 KDA(Key-Value Attention) 지원, MMA 파이프라인 스톨링 개선 등 추가적인 최적화를 통해 더욱 발전할 가능성이 있습니다. 이는 vLLM이 최신 LLM 아키텍처를 효율적으로 지원하고 성능을 지속적으로 개선해나가는 중요한 발걸음입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글