[vllm] vLLM, DeepSeek-V4 K 캐시 커널 최적화: CuteDSL 도입으로 성능 향상
PR 링크: vllm-project/vllm#42236 상태: Merged | 변경: +0 / -0
들어가며
최근 대규모 언어 모델(LLM)의 발전 속도가 매우 빠르며, 이러한 모델들을 효율적으로 서빙하는 것은 중요한 과제입니다. 특히, 추론 시 발생하는 메모리 병목 현상은 모델의 응답 속도를 저하시키는 주요 원인 중 하나입니다. vLLM은 이러한 문제를 해결하기 위해 다양한 최적화 기법을 도입해왔습니다. 본 블로그 글에서는 vLLM 프로젝트의 PR([DSv4] Improved dequant gather K cache kernel)을 분석하여, DeepSeek-V4 모델의 K 캐시(Key Cache)를 처리하는 커널의 성능을 개선한 내용을 자세히 살펴보겠습니다. 이 PR은 특히 dequantize_and_gather_k_cache 함수의 메모리 대역폭 활용도를 높이는 데 초점을 맞추고 있으며, 이를 위해 CuteDSL이라는 새로운 도구를 도입했습니다.
코드 변경 분석
이번 PR의 핵심은 dequantize_and_gather_k_cache 함수의 성능을 개선하는 것입니다. 기존에는 Triton으로 구현된 커널을 사용했지만, 특정 상황에서 메모리 대역폭 활용도가 낮다는 문제가 발견되었습니다. 이를 해결하기 위해 CuteDSL을 활용한 새로운 구현이 추가되었으며, 기존 Triton 구현은 dequantize_and_gather_k_cache_triton으로 이름이 변경되고, 새로운 CuteDSL 구현은 dequantize_and_gather_k_cache_cutedsl로 분리되었습니다. 최종적으로 dequantize_and_gather_k_cache 함수는 has_cutedsl() 조건에 따라 두 구현 중 하나를 선택하도록 변경되었습니다.
1. tests/kernels/test_compressor_kv_cache.py
테스트 파일에서는 새로운 기능에 대한 검증 로직이 추가되었습니다. 특히, Test B: Fused dequant+gather K cache 섹션이 신설되어 dequantize_and_gather_k_cache 함수의 정확성을 검증하는 테스트 케이스가 추가되었습니다. 또한, 기존의 테스트 케이스들이 A, B, C, D에서 A, B, C, D, E로 재구성되었습니다.
Before (기존 테스트 구조 일부):
@@ -134,7 +140,140 @@ def test_deepseek_v4_attention_quant_cache_roundtrip(num_tokens: int, block_size
)
-# ── Test B: Indexer path ────────────────────────────────────────────────────
+# ── Test B: Fused dequant+gather K cache ────────────────────────────────────
+
+
+def _dequantize_and_gather_k_cache_reference(
+ out: torch.Tensor,
+ k_cache: torch.Tensor,
+ seq_lens: torch.Tensor,
+ gather_lens: torch.Tensor | None,
+ block_table: torch.Tensor,
+ block_size: int,
+ offset: int,
+) -> None:
+ fp8_dim = 448
+ bf16_dim = 64
+ scale_dim = 8
+ quant_block = 64
+ token_data_size = fp8_dim + bf16_dim * 2
+
+ for req_id in range(seq_lens.shape[0]):
+ seq_len = seq_lens[req_id].item()
+ gather_len = gather_lens[req_id].item() if gather_lens is not None else seq_len
+ start_pos = seq_len - gather_len
+
+ for i in range(gather_len):
+ pos = start_pos + i
+ pos_in_block = pos % block_size
+ block_idx = block_table[req_id, pos // block_size].item()
+ cache_block = k_cache[block_idx].view(-1)
+
+ token_data_start = pos_in_block * token_data_size
+ fp8_bytes = cache_block[token_data_start : token_data_start + fp8_dim]
+ fp8_vals = fp8_bytes.view(torch.float8_e4m3fn).float()
+
+ scale_start = block_size * token_data_size + pos_in_block * scale_dim
+ encoded_scales = cache_block[scale_start : scale_start + scale_dim]
+ scales = torch.exp2(encoded_scales[:7].float() - 127.0)
+ dequant = fp8_vals * scales.repeat_interleave(quant_block)
+
+ bf16_start = token_data_start + fp8_dim
+ bf16_bytes = cache_block[bf16_start : bf16_start + bf16_dim * 2]
+ bf16_tail = bf16_bytes.view(torch.bfloat16)
+
+ out[req_id, offset + i, :fp8_dim] = dequant
+ out[req_id, offset + i, fp8_dim:] = bf16_tail
+
+
+@pytest.mark.parametrize(
+ ("seq_lens_host", "gather_lens_host", "offset"),
+ [
+ ([9, 23, 7], None, 0),
+ ([19, 8, 257], [6, 8, 129], 5),
+ ],
+)
+def test_dequantize_and_gather_k_cache(
+ seq_lens_host: list[int],
+ gather_lens_host: list[int] | None,
+ offset: int,
+):
+ block_size = 64
+ head_dim = 512
+ nope_dim = 448
+ scale_dim = 8
+ head_bytes = nope_dim + (head_dim - nope_dim) * 2 + scale_dim
+ device = "cuda"
+ num_reqs = len(seq_lens_host)
+ num_tokens = sum(seq_lens_host)
+ max_gather_len = max(gather_lens_host or seq_lens_host)
+ max_blocks_per_seq = math.ceil(max(seq_lens_host) / block_size)
+ num_blocks = sum(math.ceil(seq_len / block_size) for seq_len in seq_lens_host)
+
+ compressed_kv = torch.randn(
+ num_tokens, head_dim, dtype=torch.bfloat16, device=device
+ )
+
+ # Randomize physical pages so the test covers block-table translation.
+ # Keep padded block-table entries invalid to catch accidental reads.
+ physical_blocks = torch.randperm(num_blocks, device=device)
+ block_table = torch.full(
+ (num_reqs, max_blocks_per_seq), int(-1e6), dtype=torch.int32, device=device
+ )
+ start = 0
+ for req_id, seq_len in enumerate(seq_lens_host):
+ num_req_blocks = math.ceil(seq_len / block_size)
+ req_blocks = physical_blocks[start : start + num_req_blocks]
+ block_table[req_id, :num_req_blocks] = req_blocks
+ start += num_req_blocks
+
+ # Build slot_mapping for quantize_and_insert_k_cache.
+ slot_mapping = torch.empty(num_tokens, dtype=torch.int64, device=device)
+ start = 0
+ for req_id, seq_len in enumerate(seq_lens_host):
+ logical_pos = torch.arange(seq_len, dtype=torch.int64, device=device)
+ block_idx = block_table[req_id, logical_pos // block_size].to(torch.int64)
+ token_slots = block_idx * block_size + logical_pos % block_size
+ slot_mapping[start : start + seq_len] = token_slots
+ start += seq_len
+
+ # Insert compressed K into the paged cache layout used by the gather op.
+ k_cache = torch.empty(
+ num_blocks, block_size, head_bytes, dtype=torch.uint8, device=device
+ )
+ k_cache_2d = k_cache.view(num_blocks, -1)
+ quantize_and_insert_k_cache(compressed_kv, k_cache_2d, slot_mapping, block_size)
+
+ out_shape = (num_reqs, offset + max_gather_len + 3, head_dim)
+ ref_out = torch.empty(out_shape, dtype=torch.bfloat16, device=device)
+ actual_out = torch.empty_like(ref_out)
+ seq_lens = torch.tensor(seq_lens_host, dtype=torch.int32, device=device)
+ gather_lens = (
+ torch.tensor(gather_lens_host, dtype=torch.int32, device=device)
+ if gather_lens_host is not None
+ else None
+ )
+
+ # Compare production gather against a PyTorch reference for valid output rows.
+ _dequantize_and_gather_k_cache_reference(
+ ref_out, k_cache, seq_lens, gather_lens, block_table, block_size, offset
+ )
+ dequantize_and_gather_k_cache(
+ actual_out, k_cache, seq_lens, gather_lens, block_table, block_size, offset
+ )
+ torch.accelerator.synchronize()
+
+ # only check non-padded content
+ for req_id, seq_len in enumerate(seq_lens_host):
+ gather_len = (
+ gather_lens_host[req_id] if gather_lens_host is not None else seq_len
+ )
+ actual = actual_out[req_id, offset : offset + gather_len]
+ expected = ref_out[req_id, offset : offset + gather_len]
+ torch.testing.assert_close(actual, expected, rtol=0, atol=0)
+
+
-# ── Test C: Indexer path ────────────────────────────────────────────────────
+# ── Test C: Indexer path ────────────────────────────────────────────────────
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 17])
After (새로운 테스트 구조):
@@ -254,7 +388,7 @@ def test_indexer_gather_accepts_upper_bound_output():
assert torch.all(dst_scale[valid_tokens:] == sentinel)
-# ── Test C: DeepseekV4 attention with values at different magnitudes ───────────
+# ── Test D: DeepseekV4 attention with values at different magnitudes ───────────
def test_deepseek_v4_quant_magnitude_range():
@@ -316,7 +450,7 @@ def test_deepseek_v4_quant_magnitude_range():
)
-# ── Test D: Indexer fused K-cache insert (Triton kernels) ────────────────────
+# ── Test E: Indexer fused K-cache insert (Triton kernels) ────────────────────
#
# Both kernels share the same Triton signature; use_fp4 selects between them.
# Full pipeline: state-cache gather → softmax-weighted compress → RMSNorm →
2. vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py
이 파일에서는 dequantize_and_gather_k_cache 함수의 구현이 변경되었습니다. 기존의 Triton 기반 구현은 dequantize_and_gather_k_cache_triton으로 이름이 변경되었고, 새로운 CuteDSL 기반 구현을 호출하는 dequantize_and_gather_k_cache 함수가 추가되었습니다. has_cutedsl() 함수를 통해 CuteDSL 사용 가능 여부를 확인하고, 사용 가능하다면 새로운 커널을 사용하고, 그렇지 않다면 기존 Triton 커널을 사용하도록 로직이 변경되었습니다.
Before:
@@ -17,6 +17,7 @@
import torch
+from vllm.utils.import_utils import has_cutedsl
@triton.jit
@@ -303,7 +304,7 @@ def _dequantize_and_gather_k_kernel(
tl.store(output_row_ptr + bf16_output_offset + chunk_offsets, bf16_vals)
-def dequantize_and_gather_k_cache(
+def dequantize_and_gather_k_cache_triton(
# [num_reqs, max_num_tokens, head_size]
out: torch.Tensor,
# [num_blocks, block_size, head_bytes]
@@ -349,6 +350,34 @@ def dequantize_and_gather_k_cache(
)
+def dequantize_and_gather_k_cache(
+ # [num_reqs, max_num_tokens, head_size]
+ out: torch.Tensor,
+ # [num_blocks, block_size, head_bytes]
+ k_cache: torch.Tensor,
+ # [num_reqs]
+ seq_lens: torch.Tensor,
+ # [num_reqs]
+ gather_lens: torch.Tensor | None,
+ # [num_reqs, max_blocks_per_seq]
+ block_table: torch.Tensor,
+ block_size: int,
+ offset: int,
+) -> None:
+ if has_cutedsl():
+ # lazily import, otherwise some tests fail due to CUDA driver init failure.
+ from .dequant_gather_k_cutedsl import dequantize_and_gather_k_cache_cutedsl
+
+ dequantize_and_gather_k_cache_cutedsl(
+ out, k_cache, seq_lens, gather_lens, block_table, block_size, offset
+ )
+ return
+
+ dequantize_and_gather_k_cache_triton(
+ out, k_cache, seq_lens, gather_lens, block_table, block_size, offset
+ )
+
+
def compute_global_topk_indices_and_lens(
topk_indices: torch.Tensor,
token_to_req_indices: torch.Tensor,
3. vllm/v1/attention/ops/deepseek_v4_ops/cutedsl_utils.py
이 파일은 새로 추가된 것으로, CuteDSL을 사용하여 GPU 커널을 작성하기 위한 유틸리티 함수들을 포함하고 있습니다. _recast_val, _fp32x2_to_bf16x2, _bf16x2_to_fp32, _bf16x2_abs, _bf16x2_max, _bf16x2_mul 등 다양한 저수준 연산들을 정의하여, 효율적인 데이터 타입 변환 및 연산을 가능하게 합니다. 이 파일은 dequantize_and_gather_k_cache_cutedsl 함수 구현의 기반이 됩니다.
New File:
@@ -0,0 +1,145 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Float32, Uint32
+from cutlass._mlir import ir
+from cutlass._mlir.dialects import llvm, vector
+from cutlass.cutlass_dsl import T, dsl_user_op
+
+
+@dsl_user_op
+def _recast_val(x, dtype, *, loc=None, ip=None):
+ return dtype(llvm.bitcast(dtype.mlir_type, x.ir_value(loc=loc, ip=ip)))
+
+
+@dsl_user_op
+def _fp32x2_to_bf16x2(a: Float32, b: Float32, *, loc=None, ip=None) -> Uint32:
+ out = llvm.inline_asm(
+ T.i32(),
+ [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)],
+ "cvt.rn.bf16x2.f32 $0, $2, $1;",
+ "=r,f,f",
+ has_side_effects=False,
+ is_align_stack=False,
+ )
+ return Uint32(out)
+
+
+@dsl_user_op
+def _bf16x2_to_fp32(data: Uint32, *, loc=None, ip=None) -> tuple[Float32, Float32]:
+ out = llvm.inline_asm(
+ llvm.StructType.get_literal([T.f32(), T.f32()]),
+ [data.ir_value(loc=loc, ip=ip)],
+ "shl.b32 $0, $2, 16;\n\tand.b32 $1, $2, 0xFFFF0000;\n",
+ "=f,=f,r",
+ has_side_effects=False,
+ is_align_stack=False,
+ )
+ return (
+ Float32(llvm.extractvalue(T.f32(), out, [0], loc=loc, ip=ip)),
+ Float32(llvm.extractvalue(T.f32(), out, [1], loc=loc, ip=ip)),
+ )
+
+
+@dsl_user_op
+def _bf16x2_abs(a: Uint32, *, loc=None, ip=None) -> Uint32:
+ out = llvm.inline_asm(
+ T.i32(),
+ [a.ir_value(loc=loc, ip=ip)],
+ "abs.bf16x2 $0, $1;",
+ "=r,r",
+ has_side_effects=False,
+ is_align_stack=False,
+ )
+ return Uint32(out)
+
+
+@dsl_user_op
+def _bf16x2_max(a: Uint32, b: Uint32, *, loc=None, ip=None) -> Uint32:
+ out = llvm.inline_asm(
+ T.i32(),
+ [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)],
+ "max.bf16x2 $0, $1, $2;",
+ "=r,r,r",
+ has_side_effects=False,
+ is_align_stack=False,
+ )
+ return Uint32(out)
+
+
+@dsl_user_op
+def _bf16x2_mul(a: Uint32, b: Uint32, *, loc=None, ip=None) -> Uint32:
+ out = llvm.inline_asm(
+ T.i32(),
+ [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)],
+ "mul.rn.bf16x2 $0, $1, $2;",
+ "=r,r,r",
+ has_side_effects=False,
+ is_align_stack=False,
+ )
+ return Uint32(out)
+
+
+@dsl_user_op
+def _fp8x4_to_bf16x4(x: Uint32, *, loc=None, ip=None) -> cute.TensorSS
+```
### 4. `vllm/v1/attention/ops/deepseek_v4_ops/dequant_gather_k_cutedsl.py` (추정)
PR의 diff에는 직접적으로 보이지 않지만, `cache_utils.py`에서 `from .dequant_gather_k_cutedsl import dequantize_and_gather_k_cache_cutedsl`와 같이 임포트되는 것을 볼 때, 이 파일에 CuteDSL을 사용한 실제 `dequantize_and_gather_k_cache` 구현이 존재할 것으로 추정됩니다. 이 커널은 DeepSeek-V4의 FP8 및 BF16 데이터 형식을 처리하며, 메모리 접근 패턴을 최적화하여 대역폭 활용도를 높이는 데 중점을 둘 것입니다.
## 왜 이게 좋은가?
### 성능 향상
PR 설명에 포함된 마이크로벤치마크 결과는 이 변경이 상당한 성능 향상을 가져왔음을 명확히 보여줍니다.
**GB200에서의 성능 비교 (Microbenchmarks):**
* **Compressed full gather (offset=0):**
* `k_len=512`에서 Triton은 280.17 GB/s, CuteDSL은 47.65 GB/s를 기록했습니다. (이 부분은 PR 설명의 "it can only reach ~60GB/s for a single request"와 상반되는 결과로 보이며, 아마도 `cutedsl_gbps` 계산 방식이나 `moved_bytes` 계산에 미묘한 차이가 있을 수 있습니다. 하지만 `speedup` 지표가 5.88배로 매우 높습니다.)
* `k_len=8192`에서 Triton은 61.11 GB/s, CuteDSL은 2425.98 GB/s를 기록하며 **약 39.7배**의 속도 향상을 보였습니다.
* `k_len=32000`에서는 Triton 61.75 GB/s 대비 CuteDSL 4261.90 GB/s로 **약 69배**의 속도 향상을 기록했습니다.
* `k_len=262144`에서는 Triton 62.56 GB/s 대비 CuteDSL 5947.41 GB/s로 **약 95배**의 속도 향상을 기록했습니다.
* **Uncompressed swa tail (offset=262144):**
* `seq/gather=8192/512`에서 Triton 47.21 GB/s 대비 CuteDSL 268.50 GB/s로 **약 5.69배**의 속도 향상을 보였습니다.
* `seq/gather=1048576/1024`에서 Triton 52.34 GB/s 대비 CuteDSL 505.41 GB/s로 **약 9.66배**의 속도 향상을 보였습니다.
**Prefill Benchmark (DSv4-Flash, 4x GB200):**
* **8k-1, concurrency 256:**
* TPGS (Tokens Per GPU Second)가 35732에서 37390으로 **+4.6%** 증가했습니다.
* TTFT (Time To First Token)가 14.65s에서 13.68s로 **감소**했습니다.
* **128k-1, concurrency 256:**
* TPGS가 30502에서 32513으로 **+6.6%** 증가했습니다.
* TTFT가 267.68s에서 250.58s로 **감소**했습니다.
이러한 결과는 CuteDSL을 사용한 새로운 커널이 기존 Triton 구현보다 훨씬 높은 메모리 대역폭 활용도를 달성하며, 결과적으로 추론 성능 향상으로 이어진다는 것을 보여줍니다.
### 일반적인 교훈
1. **메모리 대역폭 최적화의 중요성:** LLM 추론에서 메모리 I/O는 병목 현상의 주요 원인입니다. 커널 수준에서의 메모리 접근 패턴 최적화는 성능 향상에 결정적인 역할을 합니다.
2. **새로운 도구의 도입:** Triton은 이미 강력한 GPU 프로그래밍 도구이지만, 특정 연산이나 하드웨어 아키텍처에 더 적합한 도구가 존재할 수 있습니다. CuteDSL과 같은 DSL(Domain-Specific Language)은 특정 연산에 대해 더 높은 수준의 추상화와 최적화를 제공할 수 있습니다.
3. **점진적 도입 및 조건부 사용:** 새로운 기술(CuteDSL)을 도입할 때, 기존 코드와의 호환성을 유지하고 점진적으로 적용하는 것이 중요합니다. `has_cutedsl()`을 통한 조건부 로직은 이러한 접근 방식을 잘 보여줍니다. 이를 통해 하드웨어 호환성 문제를 관리하고, 최신 하드웨어에서는 최적의 성능을, 그렇지 않은 환경에서는 안정적인 성능을 제공할 수 있습니다.
4. **철저한 테스트:** 새로운 커널 구현은 정확성(accuracy)과 성능(performance) 모두에서 철저하게 검증되어야 합니다. 이 PR에서는 마이크로벤치마크와 함께 정확성 테스트가 포함되어 신뢰성을 높였습니다.
## 리뷰 피드백 반영
리뷰 과정에서 `zyongye`님이 Unit Test의 필요성을 제기했고, `gau-nernst`님이 이에 대한 답변으로 Codex를 활용하여 테스트 케이스를 추가했음을 밝혔습니다. 이는 코드 변경의 정확성을 보장하는 데 중요한 단계였습니다. 또한, `gau-nernst`님은 새로운 커널이 "pre-hopper" 아키텍처에서는 작동하지 않을 수 있다는 점을 언급했습니다. 이는 새로운 기술 도입 시 발생할 수 있는 하드웨어 호환성 문제를 시사하며, `has_cutedsl()`과 같은 조건부 로직의 중요성을 다시 한번 강조합니다.
## 결론
이번 vLLM의 PR은 DeepSeek-V4 모델의 K 캐시 처리 성능을 획기적으로 개선했습니다. CuteDSL이라는 새로운 도구를 성공적으로 도입하여 기존 Triton 커널의 메모리 대역폭 활용도 한계를 극복하고, 실제 추론 성능 향상으로 이어졌습니다. 이는 LLM 서빙 성능 최적화에 있어 커널 수준의 최적화와 새로운 기술 도입의 중요성을 보여주는 좋은 사례입니다. 앞으로도 vLLM은 지속적인 성능 개선을 통해 더욱 효율적인 LLM 추론 환경을 제공할 것으로 기대됩니다.
## 참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py
- https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/ops/deepseek_v4_ops/cutedsl_utils.py
- https://github.com/vllm-project/vllm/blob/main/vllm/utils/import_utils.py
> ⚠️ **알림:** 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [sglang] SGLang의 Breakable CUDA Graph 최적화: 배치 사이즈 제한 극복하기
- 현재글 : [vllm] vLLM, DeepSeek-V4 K 캐시 커널 최적화: CuteDSL 도입으로 성능 향상
- 다음글 [sglang] SGLang NPU 최적화: MoE 모델을 위한 Dual Stream 병렬 처리 도입
댓글