본문으로 건너뛰기

[sglang] DeepSeek V4의 Prefill 성능을 1.35배 향상시킨 FlashAttention 최적화

PR 링크: sgl-project/sglang#25418 상태: Merged | 변경: +959 / -5

들어가며

대규모 언어 모델(LLM)의 추론 성능은 특히 긴 시퀀스를 처리할 때 병목 현상을 겪기 쉽습니다. LLM의 추론 과정 중 'Prefill' 단계는 새로운 시퀀스의 초기 토큰들을 처리하는 부분으로, 이 단계의 효율성은 전체 응답 생성 속도에 지대한 영향을 미칩니다. SGLang 프로젝트의 이번 Pull Request(PR)는 DeepSeek V4 모델의 Prefill 단계 성능을 획기적으로 개선하는 것을 목표로 합니다. 기존의 flash_mla_with_kvcache 커널이 가진 복잡한 로딩 로직으로 인한 성능 저하 문제를 해결하고, 더 효율적인 flash_mla_sparse_fwd 커널을 도입하여 Prefill 성능을 최대 1.35배까지 향상시켰습니다.

이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 왜 이러한 변경이 성능 향상으로 이어졌는지, 그리고 이 최적화가 가지는 일반적인 교훈은 무엇인지 살펴보겠습니다.

코드 분석

이번 PR의 핵심은 Prefill 단계에서 기존의 flash_mla_with_kvcache 대신 flash_mla_sparse_fwd 커널을 사용하도록 전환하는 것입니다. 또한, 긴 입력 시퀀스에서 발생하는 버그를 회피하기 위해 flash_mla_sparse_fwd 커널로 라우팅하는 로직도 추가되었습니다.

1. 환경 변수 설정 (sglang/srt/environ.py)

새로운 최적화를 활성화하기 위한 환경 변수가 추가되었습니다.

--- a/python/sglang/srt/environ.py
+++ b/python/sglang/srt/environ.py
@@ -688,6 +688,7 @@ class Envs:
     SGLANG_OPT_USE_COMPRESSOR_V2 = EnvBool(True)
     SGLANG_FP8_PAGED_MQA_LOGITS_TORCH = EnvBool(False)
     SGLANG_TOPK_TRANSFORM_512_TORCH = EnvBool(False)
+    SGLANG_OPT_FLASHMLA_SPARSE_PREFILL = EnvBool(False)
 
     # SWA radix cache
     SGLANG_OPT_CACHE_SWA_TRANSLATION = EnvBool(True)

SGLANG_OPT_FLASHMLA_SPARSE_PREFILL 환경 변수가 추가되어, 이 값이 True로 설정될 경우 새로운 flash_mla_sparse_fwd 기반의 Prefill 로직이 활성화됩니다. 이는 실험적인 기능을 쉽게 켜고 끌 수 있게 하여 개발 및 테스트 과정에서 유용합니다.

2. DeepSeek V4 Attention 백엔드 수정 (sglang/srt/layers/attention/deepseek_v4_backend.py)

가장 핵심적인 변경이 이루어진 부분입니다. Prefill 로직을 flash_mla_sparse_fwd로 전환하고 관련 데이터 구조를 관리합니다.

2.1. DSV4AttnMetadata 클래스 수정

flash_mla_sparse_fwd 커널에 필요한 새로운 필드들이 추가되었습니다.

--- a/python/sglang/srt/layers/attention/deepseek_v4_backend.py
+++ b/python/sglang/srt/layers/attention/deepseek_v4_backend.py
@@ -47,6 +51,9 @@
 from sglang.srt.layers.attention.dsv4.quant_k_cache import (
     quant_to_nope_fp8_rope_bf16_pack_triton,
 )
+from sglang.srt.layers.attention.dsv4.sparse_prefill_utils import (
+    SparsePrefillChunkCache,
+)
 from sglang.srt.layers.dp_attention import (
     get_attention_cp_rank,
     get_attention_cp_size,
@@ -114,6 +121,7 @@ class DSV4AttnMetadata:
     c4_topk_lengths_clamp1: Optional[torch.Tensor] = None
     c4_sparse_topk_lengths: torch.Tensor = field(init=False)
     c4_sparse_page_indices: torch.Tensor = field(init=False)
+    c4_sparse_raw_indices: Optional[torch.Tensor] = field(init=False, default=None)
 
     c128_out_loc: Optional[torch.Tensor] = None
     c128_page_indices: Optional[torch.Tensor] = None
@@ -161,6 +169,7 @@ def copy_(self, other: DSV4AttnMetadata) -> None:
                 "c4_topk_lengths_clamp1",
                 "c4_sparse_topk_lengths",
                 "c4_sparse_page_indices",
+                "c4_sparse_raw_indices",
             ],
             assign_fields=[
                 "c1_flashmla_metadata",
@@ -245,7 +254,7 @@ def apply_cp_reindex(self) -> None:
                 f"!= pre_global_len={pre_global_len} (must remain global for compressor write path)"
             )
 
-    def init_flashmla_related(self):
+    def init_flashmla_related(self, is_prefill: bool = False):
         # c4_sparse_topk is set from model_config.index_topk per-model
         # (small model: 512, large model: 1024).
         assert self.c4_sparse_topk in (512, 1024), (
@@ -263,6 +272,8 @@ def init_flashmla_related(self):
             device=self.c4_topk_lengths_clamp1.device,
         )
         self.c4_sparse_page_indices = _pad_last_dim(self.c4_sparse_page_indices)
+        if is_prefill:
+            self.c4_sparse_raw_indices = torch.empty_like(self.c4_sparse_page_indices)
         self.c1_flashmla_metadata = _create_flashmla_metadata()
         self.c4_flashmla_metadata = _create_flashmla_metadata()
         self.c128_flashmla_metadata = _create_flashmla_metadata()
  • c4_sparse_raw_indices: flash_mla_sparse_fwd 커널에서 사용할 원시 인덱스를 저장하기 위한 필드입니다. Prefill 시에만 필요하며, is_prefill=True일 때 초기화됩니다.
  • init_flashmla_related 함수에 is_prefill 인자가 추가되어, Prefill 시에만 c4_sparse_raw_indices를 할당하도록 변경되었습니다.

2.2. DSV4Metadata 클래스 수정

Prefill 과정에서 상태를 유지하기 위한 새로운 캐시 객체가 추가되었습니다.

--- a/python/sglang/srt/layers/attention/deepseek_v4_backend.py
+++ b/python/sglang/srt/layers/attention/deepseek_v4_backend.py
@@ -287,6 +303,11 @@ def copy_(self, other: DSV4Metadata):
         maybe_copy_inplace(
             self.c128_compress_metadata, src=other.c128_compress_metadata
         )
+        self.sparse_prefill_cache = None
+
+@dataclass
+class SparsePrefillChunkCache:
+    pass
 
 
 @dataclass
@@ -1027,6 +1044,20 @@ def forward(
                     extra_indices.shape[-1] % 64 == 0
                 ), f"{extra_indices.shape=}'s last dimension is not aligned to 64"
 
+            if forward_batch.forward_mode.is_extend_without_speculative() and (
+                q.shape[0] > _LARGE_INDEXER_QUERY_THRESHOLD
+                or envs.SGLANG_OPT_FLASHMLA_SPARSE_PREFILL.get()
+            ):
+                return self._forward_prefill_sparse(
+                    q=q,
+                    layer_id=layer_id,
+                    compress_ratio=compress_ratio,
+                    forward_batch=forward_batch,
+                    token_to_kv_pool=token_to_kv_pool,
+                    core_attn_metadata=core_attn_metadata,
+                    attn_sink=attn_sink,
+                )
+
             if _is_sm120:
                 from sglang.srt.layers.attention.flash_mla_sm120 import (
                     flash_mla_with_kvcache_sm120,
@@ -1069,6 +1100,107 @@ def forward(
 
         raise NotImplementedError("ragged attention")
 
+    def _forward_prefill_sparse(
+        self,
+        q: torch.Tensor,
+        layer_id: int,
+        compress_ratio: Literal[0, 4, 128],
+        forward_batch: ForwardBatch,
+        token_to_kv_pool: DeepSeekV4TokenToKVPool,
+        core_attn_metadata: DSV4AttnMetadata,
+        attn_sink: torch.Tensor,
+    ) -> torch.Tensor:
+        """Unified prefill via flash_mla_sparse_fwd. Replaces the
+        flash_mla_with_kvcache call on the extend path. Per request, 
+        positionally gathers the SWA window (always) and the compressed
+        cache (c4/c128) into a flat bf16 workspace, then lets
+        flash_mla_sparse_fwd consume the workspace via per-query rebased
+        indices. Chunk-invariant scaffolding lives in
+        ``self.forward_metadata.sparse_prefill_cache``.
+        """
+        from sgl_kernel.flash_mla import flash_mla_sparse_fwd
+
+        # q is (b, 1, h_q, d_qk); flash_mla_sparse_fwd takes (s_q, h_q, d_qk).
+        q_flat = q.squeeze(1)
+
+        cache = self.forward_metadata.sparse_prefill_cache
+        if cache is None:
+            # ``swa_window_size`` on the pool is its storage page size, not
+            # the model's SWA window — pass both explicitly.
+            cache = SparsePrefillChunkCache.build(
+                seq_lens=forward_batch.seq_lens.to(torch.int32),
+                extend_seq_lens=forward_batch.extend_seq_lens.to(torch.int32),
+                req_pool_indices=forward_batch.req_pool_indices.to(torch.int32),
+                req_to_token=self.req_to_token,
+                full_to_swa=token_to_kv_pool.full_to_swa_index_mapping,
+                swa_window_size=SWA_WINDOW,
+                swa_page_size=token_to_kv_pool.swa_window_size,
+                num_qo_tokens=q_flat.shape[0],
+            )
+            self.forward_metadata.sparse_prefill_cache = cache
+
+        # Resolve the workspace + indices for this ratio, then dequant
+        # SWA + compressed regions directly into the workspace (no torch.cat).
+        compressed_slice = None
+        extra_k_cache = None
+        extra_page_size = None
+        flat_token_ids = None
+        if compress_ratio == 0:
+            workspace = cache.c0_workspace
+            combined_indices = cache.c0_combined_indices
+            combined_lens = cache.c0_combined_lens
+            swa_slice = workspace
+        else:
+            extra_page_size = token_to_kv_pool.get_extra_key_page_size(layer_id)
+            extra_k_cache = token_to_kv_pool.get_extra_key_buffer(layer_id)
+            if compress_ratio == 128:
+                assert core_attn_metadata.c128_page_indices is not None
+                cache.ensure_c128(core_attn_metadata.c128_page_indices)
+                flat_token_ids = cache.c128_flat_token_ids
+                workspace = cache.c128_workspace
+                combined_indices = cache.c128_combined_indices
+                combined_lens = cache.c128_combined_lens
+            else:
+                assert core_attn_metadata.c4_sparse_raw_indices is not None, (
+                    "sparse-prefill c4 path requires c4_sparse_raw_indices "
+                    "(allocated in init_flashmla_related when is_prefill=True)"
+                )
+                cache.ensure_c4(core_attn_metadata.page_table, extra_page_size)
+                flat_token_ids = cache.c4_flat_token_ids
+                workspace = cache.c4_workspace
+                combined_indices, combined_lens = cache.combine_c4_layer(
+                    c4_sparse_raw_indices=core_attn_metadata.c4_sparse_raw_indices,
+                )
+            n_compressed = flat_token_ids.shape[0]
+            compressed_slice = workspace[:n_compressed]
+            swa_slice = workspace[n_compressed:]
+
+        if compressed_slice is not None:
+            dequantize_k_cache_paged(
+                extra_k_cache,
+                flat_token_ids,
+                page_size=extra_page_size,
+                out=compressed_slice,
+            )
+        dequantize_k_cache_paged(
+            token_to_kv_pool.get_swa_key_buffer_radix(layer_id),
+            cache.swa_token_ids,
+            page_size=cache.swa_page_size,
+            out=swa_slice,
+        )
+        kv = workspace
+
+        o, _, _ = flash_mla_sparse_fwd(
+            q=q_flat,
+            kv=kv,
+            indices=combined_indices.unsqueeze(1),
+            sm_scale=self.softmax_scale,
+            d_v=self.head_dim_v,
+            attn_sink=attn_sink,
+            topk_length=combined_lens,
+        )
+        return o
+
     def expand_prefill_casually(
         self,
         num_tokens: int,
@@ -1164,10 +1296,11 @@ def make_core_attn_metadata(
 
         if need_compress:
             core_attn_metadata.init_compression_metadata()
-            core_attn_metadata.init_flashmla_related()
+            core_attn_metadata.init_flashmla_related(is_prefill=is_prefill)
         else:
             core_attn_metadata.c4_sparse_topk_lengths = None
             core_attn_metadata.c4_sparse_page_indices = None
+            core_attn_metadata.c4_sparse_raw_indices = None
             core_attn_metadata.c1_flashmla_metadata = _create_flashmla_metadata()
             core_attn_metadata.c4_flashmla_metadata = None
             core_attn_metadata.c128_flashmla_metadata = None
  • DSV4Metadata 클래스에 sparse_prefill_cache 필드가 추가되었습니다. 이 캐시는 Prefill 과정에서 재사용될 SparsePrefillChunkCache 객체를 저장합니다. copy_ 메소드에서 None으로 초기화하여 CUDA Graph 재생성 시 새로운 캐시를 사용하도록 합니다.
  • forward 메소드에서 SGLANG_OPT_FLASHMLA_SPARSE_PREFILL 환경 변수가 활성화되었거나, 입력 시퀀스 길이가 _LARGE_INDEXER_QUERY_THRESHOLD보다 클 경우, _forward_prefill_sparse 메소드를 호출하도록 변경되었습니다.
  • _forward_prefill_sparse 메소드가 새로 추가되었습니다. 이 메소드는 flash_mla_sparse_fwd 커널을 호출하는 핵심 로직을 담고 있습니다. SWA(Sliding Window Attention) 및 압축된 캐시 데이터를 효율적으로 준비하고 커널에 전달합니다.
  • make_core_attn_metadata 함수에서 init_flashmla_related 호출 시 is_prefill 인자를 전달하도록 수정되었습니다.

2.3. dequant_k_cache_paged 함수 추가 (sglang/srt/layers/attention/dsv4/dequant_k_cache.py)

KV 캐시를 양자화된 형식에서 실제 연산에 사용될 형식으로 복원하는 새로운 함수가 추가되었습니다.

--- /dev/null
+++ b/python/sglang/srt/layers/attention/dsv4/dequant_k_cache.py
@@ -0,0 +1,226 @@
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
+
+fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn
+
+# v4 KV cache layout (see dsv4.index_buf_accessor._set_k_and_s_triton_kernel):
+#   per-token: 448 fp8 nope + 64 bf16 rope (= 576 contiguous bytes) +
+#              7 ue8m0 scales padded to 8 bytes.
+#   per-page:  [token 0..P-1 nope+rope (P*576 bytes)] [token 0..P-1 scale (P*8 bytes)]
+#              padded up to a multiple of 576.
+DIM_NOPE = 448
+DIM_ROPE = 64
+TILE_SIZE = 64  # one nope scale tile = 64 fp8 values
+NUM_SCALE_TILES = DIM_NOPE // TILE_SIZE  # 7
+NOPE_ROPE_BYTES = DIM_NOPE + DIM_ROPE * 2  # 576
+PADDED_SCALE_PER_TOKEN = NUM_SCALE_TILES + 1  # 8
+
+
def dequantize_k_cache_paged(
+    quant_k_cache: torch.Tensor,
+    page_table_1_flattened: torch.Tensor,
+    page_size: int,
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """Dequantize the DeepSeek v4 paged KV cache for a list of token IDs.
+
+    Args:
+        quant_k_cache: (num_pages, bytes_per_page_padded) uint8.
+        page_table_1_flattened: (num_tokens,) int — token IDs into the cache.
+        page_size: number of tokens per page.
+        out: op
+

이 함수는 flash_mla_sparse_fwd 커널이 필요로 하는 KV 캐시 데이터를 준비하는 데 사용됩니다. 특히 FP8 형식으로 양자화된 키(K) 캐시를 복원하는 역할을 합니다.

3. Sparse Prefill 유틸리티 추가 (sglang/srt/layers/attention/dsv4/sparse_prefill_utils.py)

flash_mla_sparse_fwd 커널과 함께 사용될 SparsePrefillChunkCache 클래스가 추가되었습니다. 이 클래스는 Prefill 과정에서 필요한 다양한 인덱스와 데이터를 효율적으로 관리합니다.

(참고: sparse_prefill_utils.py 파일의 전체 diff는 제공되지 않았으나, DSV4Metadata 클래스에서 이 클래스의 build 메소드를 사용하는 것으로 보아 Prefill 관련 상태 관리를 담당하는 핵심 컴포넌트임을 알 수 있습니다.)

왜 이게 좋은가?

이번 PR의 핵심은 flash_mla_sparse_fwd 커널의 도입과 이를 활용하기 위한 데이터 준비 로직의 개선입니다. 이 변경이 성능 향상으로 이어지는 이유는 다음과 같습니다.

  1. 효율적인 데이터 로딩: 기존 flash_mla_with_kvcache 커널은 복잡한 로딩 로직을 가지고 있어 성능 저하의 원인이 되었습니다. 반면, flash_mla_sparse_fwd는 TMA(Tensor Memory Accelerator) 로드를 직접 사용하여 데이터를 더 효율적으로 로드합니다. 이는 GPU 메모리에서 데이터를 가져오는 시간을 단축시켜 전체적인 연산 속도를 높입니다.
  2. 커널 통합 및 최적화: flash_mla_sparse_fwd는 Prefill 단계에 특화되어 최적화된 커널입니다. SWA(Sliding Window Attention) 및 압축된 캐시 데이터를 통합하여 단일 커널에서 처리함으로써, 여러 커널 호출 및 데이터 복사 오버헤드를 줄입니다.
  3. 성능 향상 수치: PR 설명에 따르면, flash_mla_sparse_fwd 커널을 사용했을 때 기존 flash_mla_with_kvcache 대비 Prefill 단계에서 1.35배의 속도 향상을 보였습니다. 전체 CUDA 월 타임(wall time) 기준으로는 1.1배의 속도 향상이 관찰되었습니다. 이는 매우 유의미한 성능 개선입니다.
  4. 긴 시퀀스 처리 버그 해결: PR 설명에 언급된 issues/25484는 긴 입력 시퀀스에서 발생하는 버그로, 이를 flash_mla_sparse_fwd 커널로 라우팅하여 해결했습니다. 이는 모델의 안정성과 활용성을 높이는 중요한 개선입니다.
  5. 긴 시퀀스 Prefill 지원: 기존에는 Chunk Prefill이 최대 8192 시퀀스 길이까지만 지원되었으나, 이 PR을 통해 32768과 같은 더 긴 시퀀스 길이에서도 정상적으로 작동하게 되었습니다. 이는 더 큰 컨텍스트를 가진 시나리오에서 모델의 활용도를 크게 높입니다.

일반적인 교훈

  • 커널 수준 최적화의 중요성: LLM 추론 성능의 병목은 종종 특정 연산(Attention 등)의 커널 구현에서 발생합니다. 최신 하드웨어 기능(TMA 등)을 활용하고, 데이터 로딩 및 처리 방식을 최적화하는 커널 개발은 성능 향상의 핵심입니다.
  • 데이터 준비 로직의 간소화: 복잡하고 비효율적인 데이터 준비 및 로딩 로직은 성능 저하의 주범이 될 수 있습니다. 데이터를 효율적으로 통합하고 커널에 직접 전달하는 방식은 성능을 크게 향상시킬 수 있습니다.
  • 환경 변수를 통한 기능 제어: 새로운 기능을 도입할 때 환경 변수를 사용하여 점진적으로 릴리즈하고 테스트하는 것은 안정성을 확보하는 좋은 방법입니다.

리뷰 피드백 반영

리뷰 과정에서 몇 가지 중요한 논의가 있었습니다.

  • H20 테스트 요청: @zcnrex는 B200 뿐만 아니라 H20 GPU에서도 SGLANG_OPT_FLASHMLA_SPARSE_PREFILL=1 플래그를 사용하여 테스트해달라고 요청했습니다. 이는 다양한 하드웨어 아키텍처에서의 호환성 및 성능을 검증하기 위함입니다.
  • AIME 25 벤치마크 결과: @Fridge003은 DeepSeek V4 Pro 모델에 이 변경 사항을 적용했을 때 AIME 25 벤치마크에서 98.75%의 높은 pass@1 정확도를 달성했음을 공유했습니다. 이는 기능 변경이 모델의 정확성에 부정적인 영향을 미치지 않음을 시사합니다.
  • Dequantization 범위: @DarkSharpnessc4 캐시의 모든 데이터를 dequantize하는지, 아니면 선택된 데이터만 dequantize하는지에 대한 질문을 했습니다. @Fridge003sparse_prefill_cache가 첫 번째 레이어에서 한 번만 계산되고 flat_token_ids 값이 변하지 않는다는 점을 들어 모든 c4 캐시를 dequantize하는 것으로 보인다고 추측했습니다. 이는 flash_mla_sparse_fwd가 KV 캐시 전체를 효율적으로 사용하도록 설계되었음을 의미할 수 있습니다.
  • 하드코딩 값 및 데이터 타입: @Fridge003은 코드 내 하드코딩된 값(_LARGE_INDEXER_QUERY_THRESHOLD 등)과 int32 대신 int64를 사용해야 하는 이유(IMA 방지)에 대해 질문했습니다. 이는 코드의 유연성과 안정성을 높이기 위한 중요한 지적입니다.
  • 참조 및 테스트 코드 제안: @Fridge003vllm 프로젝트의 유사한 코드를 참조로 제공하고, Triton 커널의 정확성을 검증하기 위해 PyTorch 참조 구현 및 비교 테스트 추가를 제안했습니다. 이는 코드의 신뢰성을 높이는 데 기여할 것입니다.

이러한 리뷰 피드백은 코드의 완성도를 높이고, 다양한 환경에서의 검증 및 잠재적 문제점 개선에 도움을 주었습니다.

결론

이번 PR은 DeepSeek V4 모델의 Prefill 성능을 획기적으로 개선하기 위해 flash_mla_sparse_fwd 커널을 성공적으로 통합했습니다. TMA 로드 활용, 데이터 준비 로직 간소화, 커널 최적화를 통해 1.35배의 속도 향상을 달성했으며, 긴 시퀀스 처리 능력과 안정성 또한 향상시켰습니다. 이는 LLM 추론 최적화 분야에서 커널 수준의 최적화가 얼마나 중요한지를 다시 한번 보여주는 사례입니다.

References

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글