[SGLang] Double Sparsity: H-Sparsity와 T-Sparsity의 이중 최적화
들어가며
Dense attention은 모든 head에서 모든 토큰에 어텐션한다. 하지만 실제 LLM의 어텐션 패턴을 관찰하면 두 가지 희소성이 발견된다. 첫째, Head-level Sparsity(H-Sparsity) -- 일부 head만 현재 query에 유의미한 기여를 한다. 둘째, Token-level Sparsity(T-Sparsity) -- 각 head 내에서도 일부 토큰만 높은 어텐션 가중치를 가진다. Double Sparsity는 이 두 가지 희소성을 동시에 활용하여 어텐션 연산을 가속한다.
이 글에서는 python/sglang/srt/layers/attention/double_sparsity_backend.py를 분석한다.
전체 구조
Double Sparsity는 2단계로 동작한다. 먼저 경량 approximate attention으로 중요 토큰을 선별하고, 선별된 토큰에 대해서만 정밀 attention을 수행한다.
Decode 단계 (시퀀스 길이 > sparse_decode_threshold)
═══════════════════════════════════════════════
1단계: Approximate Attention (T-Sparsity)
─────────────────────────────────────────
q_label = q의 sorted_channels 기준 추출
k_label = KV cache에 저장된 key의 label
q_label k_label (전체 KV)
[h, d'] [h, N, d'] (d' << d)
│ │
└──── 내적 ────────── ┘
│
▼
att_out_approx [h, N] ← 각 토큰의 대략적 중요도
│
▼
Top-K 토큰 선택 (heavy_token_num개)
2단계: Precise Sparse Attention
─────────────────────────────────────────
선택된 heavy_token_num개 토큰에 대해서만
full attention 수행 (BLOCK_SEQ=128 단위)
q [h, d] k_selected [h, K, d]
│ │
└──── 정밀 어텐션 ──── ┘
│
▼
output [h, d]
DoubleSparseAttnBackend 초기화
class DoubleSparseAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
extend_attention_fwd,
flash_decode_attention_fwd,
flash_decode_sparse_attention_fwd,
)
self.decode_attention_fwd = flash_decode_attention_fwd
self.decode_sparse_attention_fwd = flash_decode_sparse_attention_fwd
self.extend_attention_fwd = extend_attention_fwd
self.num_head = model_runner.model_config.num_attention_heads
self.head_dim = model_runner.model_config.hidden_size // self.num_head
self.heavy_token_num = model_runner.server_args.ds_heavy_token_num
self.sorted_channels = model_runner.sorted_channels
self.sparse_decode_threshold = (
model_runner.server_args.ds_sparse_decode_threshold
)
세 가지 Triton 커널을 사용한다. flash_decode_attention_fwd는 일반 dense decode, flash_decode_sparse_attention_fwd는 sparse decode, extend_attention_fwd는 prefill용이다. sorted_channels는 각 레이어의 채널 중요도 순서를 미리 계산한 텐서다.
핵심 파라미터
| 파라미터 | 설명 |
|---|---|
heavy_token_num |
Top-K로 선택할 중요 토큰 수 (ds_heavy_token_num) |
sparse_decode_threshold |
Sparse decode를 활성화하는 최소 시퀀스 길이 |
sorted_channels |
레이어별 채널 중요도 순서 (사전 계산) |
BLOCK_SEQ |
Sparse attention의 블록 크기 (128) |
init_forward_metadata: 메타데이터 초기화
Decode 모드에서는 approximate attention과 sparse attention을 위한 버퍼를 할당한다.
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode():
start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
total_num_tokens = torch.sum(forward_batch.seq_lens).item()
attn_logits = torch.empty(
(self.num_head, total_num_tokens),
dtype=self.reduce_dtype, device="cuda",
)
att_out_approx = torch.empty(
[self.num_head, bsz, max_seq_len],
dtype=self.reduce_dtype, device="cuda",
)
block_seq_num = (self.heavy_token_num + self.BLOCK_SEQ - 1) // self.BLOCK_SEQ
mid_out = torch.empty(
[bsz, self.num_head, block_seq_num, self.head_dim],
dtype=torch.float32, device="cuda",
)
mid_o_logexpsum = torch.empty(
[bsz, self.num_head, block_seq_num],
dtype=torch.float32, device="cuda",
)
att_out_approx는 각 head에서 각 KV 토큰의 approximate score를 저장한다. mid_out과 mid_o_logexpsum은 block 단위 sparse attention 결과를 임시 저장하며, 최종적으로 online softmax로 합산된다.
forward_extend: Prefill 경로
Extend에서는 K의 label을 계산하여 KV 캐시와 함께 저장한다.
def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
o = torch.empty_like(q)
k_label = torch.gather(
k, 2,
self.sorted_channels[layer.layer_id]
.unsqueeze(0).expand(k.shape[0], -1, -1),
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v, k_label
)
sorted_channels[layer_id]는 해당 레이어에서 가장 중요한 채널의 인덱스를 담고 있다. torch.gather로 K에서 이 채널들의 값만 추출하여 k_label을 만든다. 이 k_label은 KV 캐시에 K, V와 함께 저장되어 나중에 approximate attention에서 사용된다.
Extend 자체는 dense attention으로 수행한다. 스파스 최적화는 Decode에서만 적용된다.
forward_decode: 이중 스파스 어텐션
Decode의 핵심은 시퀀스 길이에 따른 분기다.
def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
k_label = torch.gather(
k, 2,
self.sorted_channels[layer.layer_id]
.unsqueeze(0).expand(k.shape[0], -1, -1),
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v, k_label
)
먼저 현재 토큰의 k_label을 계산하고 캐시에 저장한다.
분기: Dense vs Sparse
if (min_seq_len < self.heavy_token_num
or max_seq_len < self.sparse_decode_threshold):
# Dense decode: 일반 flash decode
self.decode_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
...
)
else:
# Sparse decode: 이중 스파스 어텐션
q_label = torch.gather(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
2,
self.sorted_channels[layer.layer_id]
.unsqueeze(0).expand(q.shape[0], -1, -1),
)
self.decode_sparse_attention_fwd(
q.view(...),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
o.view(...),
q_label,
forward_batch.token_to_kv_pool.get_label_buffer(layer.layer_id),
ds_req_to_token,
forward_batch.seq_lens,
max_seq_len,
layer.scaling,
layer.logit_cap,
self.heavy_token_num,
self.att_out_approx,
self.mid_out,
self.mid_o_logexpsum,
self.BLOCK_SEQ,
)
Dense 조건: 배치 내 최소 시퀀스 길이가 heavy_token_num보다 작거나, 최대 시퀀스 길이가 sparse_decode_threshold보다 작으면 sparse 최적화의 이점이 없으므로 dense decode를 사용한다.
Sparse 조건: Q에서도 sorted_channels 기준으로 label을 추출한다. decode_sparse_attention_fwd는 내부적으로 다음을 수행한다.
q_label과k_label의 내적으로 각 KV 토큰의 approximate score를 계산- Score 기준 top-K(
heavy_token_num) 토큰 선택 - 선택된 토큰에 대해서만 full precision 어텐션 수행
BLOCK_SEQ=128단위로 블록화된 결과를mid_out,mid_o_logexpsum에 저장- Online softmax로 최종 결과 합산
sorted_channels: 채널 중요도
sorted_channels는 서버 시작 시 모델 가중치를 분석하여 사전 계산된다. 각 레이어의 K projection에서 채널별 분산이 큰 순서대로 정렬한 인덱스다. 분산이 큰 채널은 토큰 간 구별력이 높으므로, 적은 수의 채널만으로도 approximate attention의 정확도를 유지할 수 있다.
Dense vs Double Sparsity 비교
| 항목 | Dense Attention | Double Sparsity |
|---|---|---|
| Decode 연산량 | O(h * N * d) | O(h * N * d') + O(h * K * d) |
| d' (label dim) | - | d의 일부 (sorted channels) |
| K (heavy tokens) | N (전체) | heavy_token_num (예: 256) |
| 추가 저장 | 없음 | k_label per token |
| 활성화 조건 | 항상 | seq_len > threshold |
| 정확도 | 정확 | 근사 (heavy token 선택에 의존) |
| 메모리 절감 | 없음 | attn_logits 크기: N → K |
시퀀스 길이 N=4096, heavy_token_num=256, d'=d/4 기준으로, approximate attention은 ~4x, precise attention은 ~16x 연산 감소 효과가 있다.
설계 근거
Double Sparsity가 Triton 백엔드의 옵션으로 구현된 이유가 있다. Attention Registry에서 triton 백엔드 선택 시 enable_double_sparsity 플래그로 분기한다.
@register_attention_backend("triton")
def create_triton_backend(runner):
if runner.server_args.enable_double_sparsity:
return DoubleSparseAttnBackend(runner)
else:
return TritonAttnBackend(runner)
이는 Double Sparsity의 Triton 커널이 기존 Triton 백엔드의 확장이기 때문이다. KV 캐시 인터페이스를 공유하면서 k_label 저장만 추가한다.
관련 포스트
참고
관련 포스트
- [SGLang] NSA (Narrow Sparse Attention): DeepSeek의 스파스 어텐션
- [논문리뷰] LVSA: Training-Free Sparse Attention for Long Video Diffusion
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [논문리뷰] OSP-Next: Efficient High-Quality Video Generation with Sparse Sequence Parallelism, HiF8 Quantization, and Reinforcement Learning
SGLang 의 다른글
- 이전글 [SGLang] NSA (Narrow Sparse Attention): DeepSeek의 스파스 어텐션
- 현재글 : [SGLang] Double Sparsity: H-Sparsity와 T-Sparsity의 이중 최적화
- 다음글 [SGLang] Hybrid Attention: Dense-Sparse 동적 전환 전략
댓글