[sglang] DeepSeek V4의 Prefill 성능을 1.35배 향상시킨 FlashAttention 최적화
PR 링크: sgl-project/sglang#25418 상태: Merged | 변경: +959 / -5
들어가며
대규모 언어 모델(LLM)의 추론 성능은 특히 긴 시퀀스를 처리할 때 병목 현상을 겪기 쉽습니다. LLM의 추론 과정 중 'Prefill' 단계는 새로운 시퀀스의 초기 토큰들을 처리하는 부분으로, 이 단계의 효율성은 전체 응답 생성 속도에 지대한 영향을 미칩니다. SGLang 프로젝트의 이번 Pull Request(PR)는 DeepSeek V4 모델의 Prefill 단계 성능을 획기적으로 개선하는 것을 목표로 합니다. 기존의 flash_mla_with_kvcache 커널이 가진 복잡한 로딩 로직으로 인한 성능 저하 문제를 해결하고, 더 효율적인 flash_mla_sparse_fwd 커널을 도입하여 Prefill 성능을 최대 1.35배까지 향상시켰습니다.
이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 왜 이러한 변경이 성능 향상으로 이어졌는지, 그리고 이 최적화가 가지는 일반적인 교훈은 무엇인지 살펴보겠습니다.
코드 분석
이번 PR의 핵심은 Prefill 단계에서 기존의 flash_mla_with_kvcache 대신 flash_mla_sparse_fwd 커널을 사용하도록 전환하는 것입니다. 또한, 긴 입력 시퀀스에서 발생하는 버그를 회피하기 위해 flash_mla_sparse_fwd 커널로 라우팅하는 로직도 추가되었습니다.
1. 환경 변수 설정 (sglang/srt/environ.py)
새로운 최적화를 활성화하기 위한 환경 변수가 추가되었습니다.
--- a/python/sglang/srt/environ.py
+++ b/python/sglang/srt/environ.py
@@ -688,6 +688,7 @@ class Envs:
SGLANG_OPT_USE_COMPRESSOR_V2 = EnvBool(True)
SGLANG_FP8_PAGED_MQA_LOGITS_TORCH = EnvBool(False)
SGLANG_TOPK_TRANSFORM_512_TORCH = EnvBool(False)
+ SGLANG_OPT_FLASHMLA_SPARSE_PREFILL = EnvBool(False)
# SWA radix cache
SGLANG_OPT_CACHE_SWA_TRANSLATION = EnvBool(True)
SGLANG_OPT_FLASHMLA_SPARSE_PREFILL 환경 변수가 추가되어, 이 값이 True로 설정될 경우 새로운 flash_mla_sparse_fwd 기반의 Prefill 로직이 활성화됩니다. 이는 실험적인 기능을 쉽게 켜고 끌 수 있게 하여 개발 및 테스트 과정에서 유용합니다.
2. DeepSeek V4 Attention 백엔드 수정 (sglang/srt/layers/attention/deepseek_v4_backend.py)
가장 핵심적인 변경이 이루어진 부분입니다. Prefill 로직을 flash_mla_sparse_fwd로 전환하고 관련 데이터 구조를 관리합니다.
2.1. DSV4AttnMetadata 클래스 수정
flash_mla_sparse_fwd 커널에 필요한 새로운 필드들이 추가되었습니다.
--- a/python/sglang/srt/layers/attention/deepseek_v4_backend.py
+++ b/python/sglang/srt/layers/attention/deepseek_v4_backend.py
@@ -47,6 +51,9 @@
from sglang.srt.layers.attention.dsv4.quant_k_cache import (
quant_to_nope_fp8_rope_bf16_pack_triton,
)
+from sglang.srt.layers.attention.dsv4.sparse_prefill_utils import (
+ SparsePrefillChunkCache,
+)
from sglang.srt.layers.dp_attention import (
get_attention_cp_rank,
get_attention_cp_size,
@@ -114,6 +121,7 @@ class DSV4AttnMetadata:
c4_topk_lengths_clamp1: Optional[torch.Tensor] = None
c4_sparse_topk_lengths: torch.Tensor = field(init=False)
c4_sparse_page_indices: torch.Tensor = field(init=False)
+ c4_sparse_raw_indices: Optional[torch.Tensor] = field(init=False, default=None)
c128_out_loc: Optional[torch.Tensor] = None
c128_page_indices: Optional[torch.Tensor] = None
@@ -161,6 +169,7 @@ def copy_(self, other: DSV4AttnMetadata) -> None:
"c4_topk_lengths_clamp1",
"c4_sparse_topk_lengths",
"c4_sparse_page_indices",
+ "c4_sparse_raw_indices",
],
assign_fields=[
"c1_flashmla_metadata",
@@ -245,7 +254,7 @@ def apply_cp_reindex(self) -> None:
f"!= pre_global_len={pre_global_len} (must remain global for compressor write path)"
)
- def init_flashmla_related(self):
+ def init_flashmla_related(self, is_prefill: bool = False):
# c4_sparse_topk is set from model_config.index_topk per-model
# (small model: 512, large model: 1024).
assert self.c4_sparse_topk in (512, 1024), (
@@ -263,6 +272,8 @@ def init_flashmla_related(self):
device=self.c4_topk_lengths_clamp1.device,
)
self.c4_sparse_page_indices = _pad_last_dim(self.c4_sparse_page_indices)
+ if is_prefill:
+ self.c4_sparse_raw_indices = torch.empty_like(self.c4_sparse_page_indices)
self.c1_flashmla_metadata = _create_flashmla_metadata()
self.c4_flashmla_metadata = _create_flashmla_metadata()
self.c128_flashmla_metadata = _create_flashmla_metadata()
c4_sparse_raw_indices:flash_mla_sparse_fwd커널에서 사용할 원시 인덱스를 저장하기 위한 필드입니다. Prefill 시에만 필요하며,is_prefill=True일 때 초기화됩니다.init_flashmla_related함수에is_prefill인자가 추가되어, Prefill 시에만c4_sparse_raw_indices를 할당하도록 변경되었습니다.
2.2. DSV4Metadata 클래스 수정
Prefill 과정에서 상태를 유지하기 위한 새로운 캐시 객체가 추가되었습니다.
--- a/python/sglang/srt/layers/attention/deepseek_v4_backend.py
+++ b/python/sglang/srt/layers/attention/deepseek_v4_backend.py
@@ -287,6 +303,11 @@ def copy_(self, other: DSV4Metadata):
maybe_copy_inplace(
self.c128_compress_metadata, src=other.c128_compress_metadata
)
+ self.sparse_prefill_cache = None
+
+@dataclass
+class SparsePrefillChunkCache:
+ pass
@dataclass
@@ -1027,6 +1044,20 @@ def forward(
extra_indices.shape[-1] % 64 == 0
), f"{extra_indices.shape=}'s last dimension is not aligned to 64"
+ if forward_batch.forward_mode.is_extend_without_speculative() and (
+ q.shape[0] > _LARGE_INDEXER_QUERY_THRESHOLD
+ or envs.SGLANG_OPT_FLASHMLA_SPARSE_PREFILL.get()
+ ):
+ return self._forward_prefill_sparse(
+ q=q,
+ layer_id=layer_id,
+ compress_ratio=compress_ratio,
+ forward_batch=forward_batch,
+ token_to_kv_pool=token_to_kv_pool,
+ core_attn_metadata=core_attn_metadata,
+ attn_sink=attn_sink,
+ )
+
if _is_sm120:
from sglang.srt.layers.attention.flash_mla_sm120 import (
flash_mla_with_kvcache_sm120,
@@ -1069,6 +1100,107 @@ def forward(
raise NotImplementedError("ragged attention")
+ def _forward_prefill_sparse(
+ self,
+ q: torch.Tensor,
+ layer_id: int,
+ compress_ratio: Literal[0, 4, 128],
+ forward_batch: ForwardBatch,
+ token_to_kv_pool: DeepSeekV4TokenToKVPool,
+ core_attn_metadata: DSV4AttnMetadata,
+ attn_sink: torch.Tensor,
+ ) -> torch.Tensor:
+ """Unified prefill via flash_mla_sparse_fwd. Replaces the
+ flash_mla_with_kvcache call on the extend path. Per request,
+ positionally gathers the SWA window (always) and the compressed
+ cache (c4/c128) into a flat bf16 workspace, then lets
+ flash_mla_sparse_fwd consume the workspace via per-query rebased
+ indices. Chunk-invariant scaffolding lives in
+ ``self.forward_metadata.sparse_prefill_cache``.
+ """
+ from sgl_kernel.flash_mla import flash_mla_sparse_fwd
+
+ # q is (b, 1, h_q, d_qk); flash_mla_sparse_fwd takes (s_q, h_q, d_qk).
+ q_flat = q.squeeze(1)
+
+ cache = self.forward_metadata.sparse_prefill_cache
+ if cache is None:
+ # ``swa_window_size`` on the pool is its storage page size, not
+ # the model's SWA window — pass both explicitly.
+ cache = SparsePrefillChunkCache.build(
+ seq_lens=forward_batch.seq_lens.to(torch.int32),
+ extend_seq_lens=forward_batch.extend_seq_lens.to(torch.int32),
+ req_pool_indices=forward_batch.req_pool_indices.to(torch.int32),
+ req_to_token=self.req_to_token,
+ full_to_swa=token_to_kv_pool.full_to_swa_index_mapping,
+ swa_window_size=SWA_WINDOW,
+ swa_page_size=token_to_kv_pool.swa_window_size,
+ num_qo_tokens=q_flat.shape[0],
+ )
+ self.forward_metadata.sparse_prefill_cache = cache
+
+ # Resolve the workspace + indices for this ratio, then dequant
+ # SWA + compressed regions directly into the workspace (no torch.cat).
+ compressed_slice = None
+ extra_k_cache = None
+ extra_page_size = None
+ flat_token_ids = None
+ if compress_ratio == 0:
+ workspace = cache.c0_workspace
+ combined_indices = cache.c0_combined_indices
+ combined_lens = cache.c0_combined_lens
+ swa_slice = workspace
+ else:
+ extra_page_size = token_to_kv_pool.get_extra_key_page_size(layer_id)
+ extra_k_cache = token_to_kv_pool.get_extra_key_buffer(layer_id)
+ if compress_ratio == 128:
+ assert core_attn_metadata.c128_page_indices is not None
+ cache.ensure_c128(core_attn_metadata.c128_page_indices)
+ flat_token_ids = cache.c128_flat_token_ids
+ workspace = cache.c128_workspace
+ combined_indices = cache.c128_combined_indices
+ combined_lens = cache.c128_combined_lens
+ else:
+ assert core_attn_metadata.c4_sparse_raw_indices is not None, (
+ "sparse-prefill c4 path requires c4_sparse_raw_indices "
+ "(allocated in init_flashmla_related when is_prefill=True)"
+ )
+ cache.ensure_c4(core_attn_metadata.page_table, extra_page_size)
+ flat_token_ids = cache.c4_flat_token_ids
+ workspace = cache.c4_workspace
+ combined_indices, combined_lens = cache.combine_c4_layer(
+ c4_sparse_raw_indices=core_attn_metadata.c4_sparse_raw_indices,
+ )
+ n_compressed = flat_token_ids.shape[0]
+ compressed_slice = workspace[:n_compressed]
+ swa_slice = workspace[n_compressed:]
+
+ if compressed_slice is not None:
+ dequantize_k_cache_paged(
+ extra_k_cache,
+ flat_token_ids,
+ page_size=extra_page_size,
+ out=compressed_slice,
+ )
+ dequantize_k_cache_paged(
+ token_to_kv_pool.get_swa_key_buffer_radix(layer_id),
+ cache.swa_token_ids,
+ page_size=cache.swa_page_size,
+ out=swa_slice,
+ )
+ kv = workspace
+
+ o, _, _ = flash_mla_sparse_fwd(
+ q=q_flat,
+ kv=kv,
+ indices=combined_indices.unsqueeze(1),
+ sm_scale=self.softmax_scale,
+ d_v=self.head_dim_v,
+ attn_sink=attn_sink,
+ topk_length=combined_lens,
+ )
+ return o
+
def expand_prefill_casually(
self,
num_tokens: int,
@@ -1164,10 +1296,11 @@ def make_core_attn_metadata(
if need_compress:
core_attn_metadata.init_compression_metadata()
- core_attn_metadata.init_flashmla_related()
+ core_attn_metadata.init_flashmla_related(is_prefill=is_prefill)
else:
core_attn_metadata.c4_sparse_topk_lengths = None
core_attn_metadata.c4_sparse_page_indices = None
+ core_attn_metadata.c4_sparse_raw_indices = None
core_attn_metadata.c1_flashmla_metadata = _create_flashmla_metadata()
core_attn_metadata.c4_flashmla_metadata = None
core_attn_metadata.c128_flashmla_metadata = None
DSV4Metadata클래스에sparse_prefill_cache필드가 추가되었습니다. 이 캐시는 Prefill 과정에서 재사용될SparsePrefillChunkCache객체를 저장합니다.copy_메소드에서None으로 초기화하여 CUDA Graph 재생성 시 새로운 캐시를 사용하도록 합니다.forward메소드에서SGLANG_OPT_FLASHMLA_SPARSE_PREFILL환경 변수가 활성화되었거나, 입력 시퀀스 길이가_LARGE_INDEXER_QUERY_THRESHOLD보다 클 경우,_forward_prefill_sparse메소드를 호출하도록 변경되었습니다._forward_prefill_sparse메소드가 새로 추가되었습니다. 이 메소드는flash_mla_sparse_fwd커널을 호출하는 핵심 로직을 담고 있습니다. SWA(Sliding Window Attention) 및 압축된 캐시 데이터를 효율적으로 준비하고 커널에 전달합니다.make_core_attn_metadata함수에서init_flashmla_related호출 시is_prefill인자를 전달하도록 수정되었습니다.
2.3. dequant_k_cache_paged 함수 추가 (sglang/srt/layers/attention/dsv4/dequant_k_cache.py)
KV 캐시를 양자화된 형식에서 실제 연산에 사용될 형식으로 복원하는 새로운 함수가 추가되었습니다.
--- /dev/null
+++ b/python/sglang/srt/layers/attention/dsv4/dequant_k_cache.py
@@ -0,0 +1,226 @@
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
+
+fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn
+
+# v4 KV cache layout (see dsv4.index_buf_accessor._set_k_and_s_triton_kernel):
+# per-token: 448 fp8 nope + 64 bf16 rope (= 576 contiguous bytes) +
+# 7 ue8m0 scales padded to 8 bytes.
+# per-page: [token 0..P-1 nope+rope (P*576 bytes)] [token 0..P-1 scale (P*8 bytes)]
+# padded up to a multiple of 576.
+DIM_NOPE = 448
+DIM_ROPE = 64
+TILE_SIZE = 64 # one nope scale tile = 64 fp8 values
+NUM_SCALE_TILES = DIM_NOPE // TILE_SIZE # 7
+NOPE_ROPE_BYTES = DIM_NOPE + DIM_ROPE * 2 # 576
+PADDED_SCALE_PER_TOKEN = NUM_SCALE_TILES + 1 # 8
+
+
def dequantize_k_cache_paged(
+ quant_k_cache: torch.Tensor,
+ page_table_1_flattened: torch.Tensor,
+ page_size: int,
+ out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ """Dequantize the DeepSeek v4 paged KV cache for a list of token IDs.
+
+ Args:
+ quant_k_cache: (num_pages, bytes_per_page_padded) uint8.
+ page_table_1_flattened: (num_tokens,) int — token IDs into the cache.
+ page_size: number of tokens per page.
+ out: op
+
이 함수는 flash_mla_sparse_fwd 커널이 필요로 하는 KV 캐시 데이터를 준비하는 데 사용됩니다. 특히 FP8 형식으로 양자화된 키(K) 캐시를 복원하는 역할을 합니다.
3. Sparse Prefill 유틸리티 추가 (sglang/srt/layers/attention/dsv4/sparse_prefill_utils.py)
flash_mla_sparse_fwd 커널과 함께 사용될 SparsePrefillChunkCache 클래스가 추가되었습니다. 이 클래스는 Prefill 과정에서 필요한 다양한 인덱스와 데이터를 효율적으로 관리합니다.
(참고: sparse_prefill_utils.py 파일의 전체 diff는 제공되지 않았으나, DSV4Metadata 클래스에서 이 클래스의 build 메소드를 사용하는 것으로 보아 Prefill 관련 상태 관리를 담당하는 핵심 컴포넌트임을 알 수 있습니다.)
왜 이게 좋은가?
이번 PR의 핵심은 flash_mla_sparse_fwd 커널의 도입과 이를 활용하기 위한 데이터 준비 로직의 개선입니다. 이 변경이 성능 향상으로 이어지는 이유는 다음과 같습니다.
- 효율적인 데이터 로딩: 기존
flash_mla_with_kvcache커널은 복잡한 로딩 로직을 가지고 있어 성능 저하의 원인이 되었습니다. 반면,flash_mla_sparse_fwd는 TMA(Tensor Memory Accelerator) 로드를 직접 사용하여 데이터를 더 효율적으로 로드합니다. 이는 GPU 메모리에서 데이터를 가져오는 시간을 단축시켜 전체적인 연산 속도를 높입니다. - 커널 통합 및 최적화:
flash_mla_sparse_fwd는 Prefill 단계에 특화되어 최적화된 커널입니다. SWA(Sliding Window Attention) 및 압축된 캐시 데이터를 통합하여 단일 커널에서 처리함으로써, 여러 커널 호출 및 데이터 복사 오버헤드를 줄입니다. - 성능 향상 수치: PR 설명에 따르면,
flash_mla_sparse_fwd커널을 사용했을 때 기존flash_mla_with_kvcache대비 Prefill 단계에서 1.35배의 속도 향상을 보였습니다. 전체 CUDA 월 타임(wall time) 기준으로는 1.1배의 속도 향상이 관찰되었습니다. 이는 매우 유의미한 성능 개선입니다. - 긴 시퀀스 처리 버그 해결: PR 설명에 언급된
issues/25484는 긴 입력 시퀀스에서 발생하는 버그로, 이를flash_mla_sparse_fwd커널로 라우팅하여 해결했습니다. 이는 모델의 안정성과 활용성을 높이는 중요한 개선입니다. - 긴 시퀀스 Prefill 지원: 기존에는 Chunk Prefill이 최대 8192 시퀀스 길이까지만 지원되었으나, 이 PR을 통해 32768과 같은 더 긴 시퀀스 길이에서도 정상적으로 작동하게 되었습니다. 이는 더 큰 컨텍스트를 가진 시나리오에서 모델의 활용도를 크게 높입니다.
일반적인 교훈
- 커널 수준 최적화의 중요성: LLM 추론 성능의 병목은 종종 특정 연산(Attention 등)의 커널 구현에서 발생합니다. 최신 하드웨어 기능(TMA 등)을 활용하고, 데이터 로딩 및 처리 방식을 최적화하는 커널 개발은 성능 향상의 핵심입니다.
- 데이터 준비 로직의 간소화: 복잡하고 비효율적인 데이터 준비 및 로딩 로직은 성능 저하의 주범이 될 수 있습니다. 데이터를 효율적으로 통합하고 커널에 직접 전달하는 방식은 성능을 크게 향상시킬 수 있습니다.
- 환경 변수를 통한 기능 제어: 새로운 기능을 도입할 때 환경 변수를 사용하여 점진적으로 릴리즈하고 테스트하는 것은 안정성을 확보하는 좋은 방법입니다.
리뷰 피드백 반영
리뷰 과정에서 몇 가지 중요한 논의가 있었습니다.
- H20 테스트 요청:
@zcnrex는 B200 뿐만 아니라 H20 GPU에서도SGLANG_OPT_FLASHMLA_SPARSE_PREFILL=1플래그를 사용하여 테스트해달라고 요청했습니다. 이는 다양한 하드웨어 아키텍처에서의 호환성 및 성능을 검증하기 위함입니다. - AIME 25 벤치마크 결과:
@Fridge003은 DeepSeek V4 Pro 모델에 이 변경 사항을 적용했을 때 AIME 25 벤치마크에서 98.75%의 높은 pass@1 정확도를 달성했음을 공유했습니다. 이는 기능 변경이 모델의 정확성에 부정적인 영향을 미치지 않음을 시사합니다. - Dequantization 범위:
@DarkSharpness는c4캐시의 모든 데이터를 dequantize하는지, 아니면 선택된 데이터만 dequantize하는지에 대한 질문을 했습니다.@Fridge003은sparse_prefill_cache가 첫 번째 레이어에서 한 번만 계산되고flat_token_ids값이 변하지 않는다는 점을 들어 모든c4캐시를 dequantize하는 것으로 보인다고 추측했습니다. 이는flash_mla_sparse_fwd가 KV 캐시 전체를 효율적으로 사용하도록 설계되었음을 의미할 수 있습니다. - 하드코딩 값 및 데이터 타입:
@Fridge003은 코드 내 하드코딩된 값(_LARGE_INDEXER_QUERY_THRESHOLD등)과int32대신int64를 사용해야 하는 이유(IMA방지)에 대해 질문했습니다. 이는 코드의 유연성과 안정성을 높이기 위한 중요한 지적입니다. - 참조 및 테스트 코드 제안:
@Fridge003은vllm프로젝트의 유사한 코드를 참조로 제공하고, Triton 커널의 정확성을 검증하기 위해 PyTorch 참조 구현 및 비교 테스트 추가를 제안했습니다. 이는 코드의 신뢰성을 높이는 데 기여할 것입니다.
이러한 리뷰 피드백은 코드의 완성도를 높이고, 다양한 환경에서의 검증 및 잠재적 문제점 개선에 도움을 주었습니다.
결론
이번 PR은 DeepSeek V4 모델의 Prefill 성능을 획기적으로 개선하기 위해 flash_mla_sparse_fwd 커널을 성공적으로 통합했습니다. TMA 로드 활용, 데이터 준비 로직 간소화, 커널 최적화를 통해 1.35배의 속도 향상을 달성했으며, 긴 시퀀스 처리 능력과 안정성 또한 향상시켰습니다. 이는 LLM 추론 최적화 분야에서 커널 수준의 최적화가 얼마나 중요한지를 다시 한번 보여주는 사례입니다.
References
- Flash Attention
- TMA (Tensor Memory Accelerator)
- DeepSeek V4 Architecture
- vLLM vllm/models/deepseek_v4/common/ops/cache_utils.py
참고 자료
- https://github.com/Dao-AILab/flash-attention
- https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tma-instructions
- https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash
- https://github.com/vllm-project/vllm/blob/124fac10cb0ea83aee2ffeabac0b413d6b759b26/vllm/models/deepseek_v4/common/ops/cache_utils.py#L476
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] [SGLang] Blackwell(B200)에서 Diffusion Attention 성능을 7배 끌어올리는 Triton 커널 최적화 분석
- [sglang] AMD ROCm 환경에서의 성능 최적화: Triton을 활용한 Fused QK GemmaRMSNorm 구현
- [sglang] SGLang: ROCm 환경에서 Qwen3-VL 디코딩 성능 극대화를 위한 커널 퓨전 최적화
- [sglang] SGLang Whisper 모델의 CUDA Graph 도입 및 성능 최적화 분석
- [sglang] SGLang의 FA3 디코드 최적화: get_scheduler_metadata 도입
PR Analysis 의 다른글
- 이전글 [feast] Feast 온라인 서빙 성능 튜닝: Sub-2ms 달성을 위한 여정
- 현재글 : [sglang] DeepSeek V4의 Prefill 성능을 1.35배 향상시킨 FlashAttention 최적화
- 다음글 [ray] Ray Data의 hash_partition 성능을 7배 향상시킨 최적화 전략
댓글