[SGLang] NSA (Narrow Sparse Attention): DeepSeek의 스파스 어텐션
들어가며
Dense attention은 모든 토큰 쌍에 대해 어텐션을 계산한다. 시퀀스 길이 N에 대해 O(N^2) 비용이 발생하며, 이는 긴 시퀀스에서 주요 병목이 된다. NSA(Narrow Sparse Attention)는 DeepSeek이 제안한 sparse attention 기법으로, 각 query에 대해 가장 중요한 K/V 토큰만 선택적으로 어텐션한다.
SGLang의 NativeSparseAttnBackend는 NSA를 위한 전용 백엔드로, 인덱서(indexer)를 통한 top-k 토큰 선택, FlashMLA/FA3/TileLang/TRT-LLM 등 다양한 sparse 커널, MLA와의 통합을 지원한다.
이 글에서는 python/sglang/srt/layers/attention/nsa_backend.py와 nsa/ 디렉토리를 분석한다.
전체 구조
NSA의 어텐션 흐름은 Dense Attention과 근본적으로 다르다.
Query (각 토큰)
│
▼
┌──────────────────────────────────────────────┐
│ NSA Indexer │
│ 1. 전체 KV에서 Score 계산 (DeepGEMM) │
│ 2. Top-K 인덱스 선택 (fast_topk) │
│ 3. 인덱스 → 페이지 테이블 변환 │
└──────────────┬───────────────────────────────┘
│ topk_indices
▼
┌──────────────────────────────────────────────┐
│ Sparse Attention Kernel │
│ │
│ ┌─ Decode ──────────────────────────────┐ │
│ │ flashmla_sparse: FlashMLA sparse │ │
│ │ flashmla_kv: FlashMLA + KV │ │
│ │ fa3: FlashAttention v3 │ │
│ │ trtllm: TRT-LLM kernel │ │
│ └───────────────────────────────────────┘ │
│ │
│ ┌─ Prefill ─────────────────────────────┐ │
│ │ flashmla_sparse/kv: FlashMLA │ │
│ │ fa3: FlashAttention v3 │ │
│ │ tilelang: TileLang kernel │ │
│ │ trtllm: TRT-LLM kernel │ │
│ └───────────────────────────────────────┘ │
└──────────────────────────────────────────────┘
NativeSparseAttnBackend 초기화
class NativeSparseAttnBackend(
NativeSparseAttnBackendMTPPrecomputeMixin, AttentionBackend
):
def __init__(self, model_runner, skip_prefill=False, ...):
self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
assert self.use_nsa, "NSA backend only supports DeepSeek NSA"
self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)
self.nsa_kv_cache_store_fp8 = (
model_runner.token_to_kv_pool.nsa_kv_cache_store_fp8
)
self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
NSA 백엔드는 DeepSeek NSA 모델 전용이다. nsa_index_topk는 각 query가 어텐션할 top-K 토큰 수를 결정한다. MLA의 compressed KV cache(kv_cache_dim = kv_lora_rank + qk_rope_head_dim)를 사용한다.
커널 구현체 선택
self.nsa_prefill_impl: _NSA_IMPL_T = model_runner.server_args.nsa_prefill_backend
self.nsa_decode_impl: _NSA_IMPL_T = model_runner.server_args.nsa_decode_backend
_NSA_IMPL_T: TypeAlias = Literal[
"flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "trtllm"
]
Prefill과 Decode에서 서로 다른 커널 구현체를 사용할 수 있다. 예를 들어 Prefill에서는 tilelang을, Decode에서는 flashmla_sparse를 사용하는 조합이 가능하다.
NSA 인덱서: Top-K 토큰 선택
NSA의 핵심은 인덱서(indexer)다. 인덱서는 각 query에 대해 가장 중요한 K 토큰의 인덱스를 선택한다.
class BaseIndexerMetadata(ABC):
@abstractmethod
def get_seqlens_int32(self) -> torch.Tensor: ...
@abstractmethod
def get_page_table_64(self) -> torch.Tensor: ...
@abstractmethod
def get_page_table_1(self) -> torch.Tensor: ...
@abstractmethod
def get_seqlens_expanded(self) -> torch.Tensor: ...
NSAIndexerMetadata는 이 추상 클래스를 구현하며, top-k 변환 방식을 TopkTransformMethod로 결정한다.
class TopkTransformMethod(IntEnum):
PAGED = auto() # top-k 인덱스 → 페이지 테이블 인덱스
RAGGED = auto() # top-k 인덱스 → ragged KV 인덱스
Fused Top-K Transform
def topk_transform(self, logits, topk, ks=None, ...):
if not envs.SGLANG_NSA_FUSE_TOPK.get() or self.force_unfused_topk:
return fast_topk_v2(logits, seq_lens_topk, topk, row_starts=ks)
elif self.topk_transform_method == TopkTransformMethod.PAGED:
return fast_topk_transform_fused(
score=logits,
lengths=seq_lens_topk,
page_table_size_1=page_table_size_1,
cu_seqlens_q=cu_seqlens_q_topk,
topk=topk,
)
elif self.topk_transform_method == TopkTransformMethod.RAGGED:
return fast_topk_transform_ragged_fused(
score=logits,
lengths=seq_lens_topk,
topk_indices_offset=cu_topk_indices_offset,
topk=topk,
)
SGLANG_NSA_FUSE_TOPK가 활성화되면 top-k 선택과 인덱스 변환을 하나의 fused 커널로 수행한다. 이 최적화는 GPU 커널 launch 오버헤드를 줄인다.
NSAMetadata: Forward Batch 메타데이터
@dataclass(frozen=True)
class NSAMetadata:
page_size: int
cache_seqlens_int32: torch.Tensor
max_seq_len_q: int
max_seq_len_k: int
cu_seqlens_q: torch.Tensor
cu_seqlens_k: torch.Tensor
page_table_1: torch.Tensor
real_page_table: torch.Tensor
# NSA 전용 메타데이터 (expanded)
nsa_cache_seqlens_int32: torch.Tensor # topk로 클리핑된 시퀀스 길이
nsa_cu_seqlens_q: torch.Tensor
nsa_cu_seqlens_k: torch.Tensor
nsa_seqlens_expanded: torch.Tensor # 확장된, 클리핑 전 시퀀스 길이
NSA 메타데이터는 두 가지 페이지 테이블을 관리한다. page_table_1은 page_size=1의 토큰 단위 테이블이고, real_page_table은 실제 페이지 크기(예: 64)의 블록 단위 테이블이다. Sparse attention은 토큰 단위 인덱싱이 필요하므로 page_table_1을 사용한다.
init_forward_metadata: 모드별 메타데이터
def init_forward_metadata(self, forward_batch):
cache_seqlens_int32 = (forward_batch.seq_lens + draft_token_num).to(torch.int32)
page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :max_seqlen_k
]
if forward_batch.forward_mode.is_decode_or_idle():
max_seqlen_q = 1
cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
seqlens_expanded = cache_seqlens_int32
elif forward_batch.forward_mode.is_extend():
seqlens_expanded = torch.cat([
torch.arange(kv_len - qo_len + 1, kv_len + 1, ...)
for qo_len, kv_len in zip(extend_seq_lens_cpu, seq_lens_cpu)
])
Extend에서 seqlens_expanded가 특이하다. 각 query 토큰에 대해 해당 시점에서 볼 수 있는 KV 시퀀스 길이를 나열한다. 예를 들어 3개 토큰 extend, KV 길이 10이면 [8, 9, 10]이 된다. 이는 인덱서가 각 query 토큰별로 다른 범위에서 top-k를 선택하기 위함이다.
forward_decode: Sparse 어텐션 실행
def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True,
q_rope=None, k_rope=None, topk_indices=None, ...):
# KV 캐시 저장
if k is not None and save_kv_cache:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer, cache_loc, k, k_rope,
)
# Absorbed Q 준비
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim)
# Top-K 인덱스 → 페이지 테이블 변환
if envs.SGLANG_NSA_FUSE_TOPK.get():
page_table_1 = topk_indices # fused: 이미 변환됨
else:
page_table_1 = transform_index_page_table_decode(
page_table=metadata.page_table_1,
topk_indices=topk_indices,
page_size=1,
)
topk_indices는 모델의 인덱서 레이어에서 미리 계산된다. NSA 백엔드는 이 인덱스를 받아 해당 위치의 KV만 로드하여 sparse 어텐션을 수행한다.
커널 디스패치
if self.nsa_decode_impl == "flashmla_sparse":
q_all = concat_mla_absorb_q_general(q_nope, q_rope)
return self._forward_flashmla_sparse(
q_all=q_all, kv_cache=kv_cache,
page_table_1=page_table_1, sm_scale=layer.scaling,
)
elif self.nsa_decode_impl == "flashmla_kv":
q_all = concat_mla_absorb_q_general(q_nope, q_rope)
return self._forward_flashmla_kv(q_all=q_all, ...)
elif self.nsa_decode_impl == "fa3":
return self._forward_fa3(q_nope=q_nope, q_rope=q_rope, ...)
NSA 서브모듈 구성
nsa/ 디렉토리는 NSA를 구성하는 전문 모듈들을 포함한다.
nsa/
├── nsa_indexer.py # Top-K 인덱서 인터페이스
├── nsa_backend_mtp_precompute.py # MTP 사전 계산 최적화
├── nsa_mtp_verification.py # MTP 검증
├── transform_index.py # 인덱스 → 페이지 테이블 변환
├── quant_k_cache.py # K 캐시 FP8 양자화
├── dequant_k_cache.py # K 캐시 역양자화
├── triton_kernel.py # Triton 커널
├── tilelang_kernel.py # TileLang 커널
├── index_buf_accessor.py # 인덱스 버퍼 접근자
└── utils.py # CP 분할, 시퀀스 길이 계산
quant_k_cache.py와 dequant_k_cache.py는 NSA 전용 FP8 양자화를 구현한다. 인덱서가 score를 계산할 때는 FP8 K cache를 역양자화하여 사용하고, 실제 어텐션에서는 FP8을 직접 사용하여 메모리 대역폭을 절약한다.
Dense vs NSA 비교
| 항목 | Dense Attention | NSA |
|---|---|---|
| 연산 복잡도 | O(N * N) | O(N * K), K << N |
| 어텐션 범위 | 전체 시퀀스 | Top-K 토큰만 |
| 추가 비용 | 없음 | 인덱서 오버헤드 |
| 정확도 | 정확 | 근사 (정보 손실 가능) |
| 128K 시퀀스 | 매우 느림 | K=512로 ~250x 연산 감소 |
관련 포스트
참고
관련 포스트
- [SGLang] Sparsity Algorithms: QUEST와 DeepSeek NSA 희소 패턴
- [SGLang] Double Sparsity: H-Sparsity와 T-Sparsity의 이중 최적화
- [SGLang] Multi-head Latent Attention (MLA): KV 캐시 압축 어텐션
- [논문리뷰] HISA: Efficient Hierarchical Indexing for Fine-Grained Sparse Attention
- [논문리뷰] DeepSeek-V3.2: Pushing the Frontier of Open Large Language Models
SGLang 의 다른글
- 이전글 [SGLang] Multi-head Latent Attention (MLA): KV 캐시 압축 어텐션
- 현재글 : [SGLang] NSA (Narrow Sparse Attention): DeepSeek의 스파스 어텐션
- 다음글 [SGLang] Double Sparsity: H-Sparsity와 T-Sparsity의 이중 최적화
댓글