본문으로 건너뛰기

[SGLang] NSA (Narrow Sparse Attention): DeepSeek의 스파스 어텐션

들어가며

Dense attention은 모든 토큰 쌍에 대해 어텐션을 계산한다. 시퀀스 길이 N에 대해 O(N^2) 비용이 발생하며, 이는 긴 시퀀스에서 주요 병목이 된다. NSA(Narrow Sparse Attention)는 DeepSeek이 제안한 sparse attention 기법으로, 각 query에 대해 가장 중요한 K/V 토큰만 선택적으로 어텐션한다.

SGLang의 NativeSparseAttnBackend는 NSA를 위한 전용 백엔드로, 인덱서(indexer)를 통한 top-k 토큰 선택, FlashMLA/FA3/TileLang/TRT-LLM 등 다양한 sparse 커널, MLA와의 통합을 지원한다.

이 글에서는 python/sglang/srt/layers/attention/nsa_backend.pynsa/ 디렉토리를 분석한다.

전체 구조

NSA의 어텐션 흐름은 Dense Attention과 근본적으로 다르다.

  Query (각 토큰)
     │
     ▼
┌──────────────────────────────────────────────┐
│              NSA Indexer                      │
│  1. 전체 KV에서 Score 계산 (DeepGEMM)         │
│  2. Top-K 인덱스 선택 (fast_topk)             │
│  3. 인덱스 → 페이지 테이블 변환               │
└──────────────┬───────────────────────────────┘
               │  topk_indices
               ▼
┌──────────────────────────────────────────────┐
│         Sparse Attention Kernel               │
│                                              │
│  ┌─ Decode ──────────────────────────────┐   │
│  │  flashmla_sparse: FlashMLA sparse     │   │
│  │  flashmla_kv:     FlashMLA + KV       │   │
│  │  fa3:             FlashAttention v3   │   │
│  │  trtllm:          TRT-LLM kernel      │   │
│  └───────────────────────────────────────┘   │
│                                              │
│  ┌─ Prefill ─────────────────────────────┐   │
│  │  flashmla_sparse/kv: FlashMLA         │   │
│  │  fa3:             FlashAttention v3   │   │
│  │  tilelang:        TileLang kernel     │   │
│  │  trtllm:          TRT-LLM kernel      │   │
│  └───────────────────────────────────────┘   │
└──────────────────────────────────────────────┘

NativeSparseAttnBackend 초기화

class NativeSparseAttnBackend(
    NativeSparseAttnBackendMTPPrecomputeMixin, AttentionBackend
):
    def __init__(self, model_runner, skip_prefill=False, ...):
        self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
        assert self.use_nsa, "NSA backend only supports DeepSeek NSA"

        self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)
        self.nsa_kv_cache_store_fp8 = (
            model_runner.token_to_kv_pool.nsa_kv_cache_store_fp8
        )
        self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim

NSA 백엔드는 DeepSeek NSA 모델 전용이다. nsa_index_topk는 각 query가 어텐션할 top-K 토큰 수를 결정한다. MLA의 compressed KV cache(kv_cache_dim = kv_lora_rank + qk_rope_head_dim)를 사용한다.

커널 구현체 선택

self.nsa_prefill_impl: _NSA_IMPL_T = model_runner.server_args.nsa_prefill_backend
self.nsa_decode_impl: _NSA_IMPL_T = model_runner.server_args.nsa_decode_backend

_NSA_IMPL_T: TypeAlias = Literal[
    "flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "trtllm"
]

Prefill과 Decode에서 서로 다른 커널 구현체를 사용할 수 있다. 예를 들어 Prefill에서는 tilelang을, Decode에서는 flashmla_sparse를 사용하는 조합이 가능하다.

NSA 인덱서: Top-K 토큰 선택

NSA의 핵심은 인덱서(indexer)다. 인덱서는 각 query에 대해 가장 중요한 K 토큰의 인덱스를 선택한다.

class BaseIndexerMetadata(ABC):
    @abstractmethod
    def get_seqlens_int32(self) -> torch.Tensor: ...

    @abstractmethod
    def get_page_table_64(self) -> torch.Tensor: ...

    @abstractmethod
    def get_page_table_1(self) -> torch.Tensor: ...

    @abstractmethod
    def get_seqlens_expanded(self) -> torch.Tensor: ...

NSAIndexerMetadata는 이 추상 클래스를 구현하며, top-k 변환 방식을 TopkTransformMethod로 결정한다.

class TopkTransformMethod(IntEnum):
    PAGED = auto()   # top-k 인덱스 → 페이지 테이블 인덱스
    RAGGED = auto()  # top-k 인덱스 → ragged KV 인덱스

Fused Top-K Transform

def topk_transform(self, logits, topk, ks=None, ...):
    if not envs.SGLANG_NSA_FUSE_TOPK.get() or self.force_unfused_topk:
        return fast_topk_v2(logits, seq_lens_topk, topk, row_starts=ks)
    elif self.topk_transform_method == TopkTransformMethod.PAGED:
        return fast_topk_transform_fused(
            score=logits,
            lengths=seq_lens_topk,
            page_table_size_1=page_table_size_1,
            cu_seqlens_q=cu_seqlens_q_topk,
            topk=topk,
        )
    elif self.topk_transform_method == TopkTransformMethod.RAGGED:
        return fast_topk_transform_ragged_fused(
            score=logits,
            lengths=seq_lens_topk,
            topk_indices_offset=cu_topk_indices_offset,
            topk=topk,
        )

SGLANG_NSA_FUSE_TOPK가 활성화되면 top-k 선택과 인덱스 변환을 하나의 fused 커널로 수행한다. 이 최적화는 GPU 커널 launch 오버헤드를 줄인다.

NSAMetadata: Forward Batch 메타데이터

@dataclass(frozen=True)
class NSAMetadata:
    page_size: int
    cache_seqlens_int32: torch.Tensor
    max_seq_len_q: int
    max_seq_len_k: int
    cu_seqlens_q: torch.Tensor
    cu_seqlens_k: torch.Tensor
    page_table_1: torch.Tensor
    real_page_table: torch.Tensor

    # NSA 전용 메타데이터 (expanded)
    nsa_cache_seqlens_int32: torch.Tensor  # topk로 클리핑된 시퀀스 길이
    nsa_cu_seqlens_q: torch.Tensor
    nsa_cu_seqlens_k: torch.Tensor
    nsa_seqlens_expanded: torch.Tensor     # 확장된, 클리핑 전 시퀀스 길이

NSA 메타데이터는 두 가지 페이지 테이블을 관리한다. page_table_1은 page_size=1의 토큰 단위 테이블이고, real_page_table은 실제 페이지 크기(예: 64)의 블록 단위 테이블이다. Sparse attention은 토큰 단위 인덱싱이 필요하므로 page_table_1을 사용한다.

init_forward_metadata: 모드별 메타데이터

def init_forward_metadata(self, forward_batch):
    cache_seqlens_int32 = (forward_batch.seq_lens + draft_token_num).to(torch.int32)
    page_table = forward_batch.req_to_token_pool.req_to_token[
        forward_batch.req_pool_indices, :max_seqlen_k
    ]

    if forward_batch.forward_mode.is_decode_or_idle():
        max_seqlen_q = 1
        cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
        seqlens_expanded = cache_seqlens_int32
    elif forward_batch.forward_mode.is_extend():
        seqlens_expanded = torch.cat([
            torch.arange(kv_len - qo_len + 1, kv_len + 1, ...)
            for qo_len, kv_len in zip(extend_seq_lens_cpu, seq_lens_cpu)
        ])

Extend에서 seqlens_expanded가 특이하다. 각 query 토큰에 대해 해당 시점에서 볼 수 있는 KV 시퀀스 길이를 나열한다. 예를 들어 3개 토큰 extend, KV 길이 10이면 [8, 9, 10]이 된다. 이는 인덱서가 각 query 토큰별로 다른 범위에서 top-k를 선택하기 위함이다.

forward_decode: Sparse 어텐션 실행

def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True,
                   q_rope=None, k_rope=None, topk_indices=None, ...):
    # KV 캐시 저장
    if k is not None and save_kv_cache:
        forward_batch.token_to_kv_pool.set_mla_kv_buffer(
            layer, cache_loc, k, k_rope,
        )

    # Absorbed Q 준비
    q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
    q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim)

    # Top-K 인덱스 → 페이지 테이블 변환
    if envs.SGLANG_NSA_FUSE_TOPK.get():
        page_table_1 = topk_indices  # fused: 이미 변환됨
    else:
        page_table_1 = transform_index_page_table_decode(
            page_table=metadata.page_table_1,
            topk_indices=topk_indices,
            page_size=1,
        )

topk_indices는 모델의 인덱서 레이어에서 미리 계산된다. NSA 백엔드는 이 인덱스를 받아 해당 위치의 KV만 로드하여 sparse 어텐션을 수행한다.

커널 디스패치

    if self.nsa_decode_impl == "flashmla_sparse":
        q_all = concat_mla_absorb_q_general(q_nope, q_rope)
        return self._forward_flashmla_sparse(
            q_all=q_all, kv_cache=kv_cache,
            page_table_1=page_table_1, sm_scale=layer.scaling,
        )
    elif self.nsa_decode_impl == "flashmla_kv":
        q_all = concat_mla_absorb_q_general(q_nope, q_rope)
        return self._forward_flashmla_kv(q_all=q_all, ...)
    elif self.nsa_decode_impl == "fa3":
        return self._forward_fa3(q_nope=q_nope, q_rope=q_rope, ...)

NSA 서브모듈 구성

nsa/ 디렉토리는 NSA를 구성하는 전문 모듈들을 포함한다.

nsa/
├── nsa_indexer.py           # Top-K 인덱서 인터페이스
├── nsa_backend_mtp_precompute.py  # MTP 사전 계산 최적화
├── nsa_mtp_verification.py  # MTP 검증
├── transform_index.py       # 인덱스 → 페이지 테이블 변환
├── quant_k_cache.py         # K 캐시 FP8 양자화
├── dequant_k_cache.py       # K 캐시 역양자화
├── triton_kernel.py         # Triton 커널
├── tilelang_kernel.py       # TileLang 커널
├── index_buf_accessor.py    # 인덱스 버퍼 접근자
└── utils.py                 # CP 분할, 시퀀스 길이 계산

quant_k_cache.pydequant_k_cache.py는 NSA 전용 FP8 양자화를 구현한다. 인덱서가 score를 계산할 때는 FP8 K cache를 역양자화하여 사용하고, 실제 어텐션에서는 FP8을 직접 사용하여 메모리 대역폭을 절약한다.

Dense vs NSA 비교

항목 Dense Attention NSA
연산 복잡도 O(N * N) O(N * K), K << N
어텐션 범위 전체 시퀀스 Top-K 토큰만
추가 비용 없음 인덱서 오버헤드
정확도 정확 근사 (정보 손실 가능)
128K 시퀀스 매우 느림 K=512로 ~250x 연산 감소

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글