본문으로 건너뛰기

[SGLang] Double Sparsity: H-Sparsity와 T-Sparsity의 이중 최적화

들어가며

Dense attention은 모든 head에서 모든 토큰에 어텐션한다. 하지만 실제 LLM의 어텐션 패턴을 관찰하면 두 가지 희소성이 발견된다. 첫째, Head-level Sparsity(H-Sparsity) -- 일부 head만 현재 query에 유의미한 기여를 한다. 둘째, Token-level Sparsity(T-Sparsity) -- 각 head 내에서도 일부 토큰만 높은 어텐션 가중치를 가진다. Double Sparsity는 이 두 가지 희소성을 동시에 활용하여 어텐션 연산을 가속한다.

이 글에서는 python/sglang/srt/layers/attention/double_sparsity_backend.py를 분석한다.

전체 구조

Double Sparsity는 2단계로 동작한다. 먼저 경량 approximate attention으로 중요 토큰을 선별하고, 선별된 토큰에 대해서만 정밀 attention을 수행한다.

  Decode 단계 (시퀀스 길이 > sparse_decode_threshold)
  ═══════════════════════════════════════════════

  1단계: Approximate Attention (T-Sparsity)
  ─────────────────────────────────────────
  q_label = q의 sorted_channels 기준 추출
  k_label = KV cache에 저장된 key의 label

      q_label                k_label (전체 KV)
     [h, d']              [h, N, d']  (d' << d)
         │                    │
         └──── 내적 ────────── ┘
                 │
                 ▼
         att_out_approx [h, N]   ← 각 토큰의 대략적 중요도
                 │
                 ▼
         Top-K 토큰 선택 (heavy_token_num개)

  2단계: Precise Sparse Attention
  ─────────────────────────────────────────
  선택된 heavy_token_num개 토큰에 대해서만
  full attention 수행 (BLOCK_SEQ=128 단위)

      q [h, d]          k_selected [h, K, d]
         │                    │
         └──── 정밀 어텐션 ──── ┘
                 │
                 ▼
           output [h, d]

DoubleSparseAttnBackend 초기화

class DoubleSparseAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
            extend_attention_fwd,
            flash_decode_attention_fwd,
            flash_decode_sparse_attention_fwd,
        )
        self.decode_attention_fwd = flash_decode_attention_fwd
        self.decode_sparse_attention_fwd = flash_decode_sparse_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd

        self.num_head = model_runner.model_config.num_attention_heads
        self.head_dim = model_runner.model_config.hidden_size // self.num_head
        self.heavy_token_num = model_runner.server_args.ds_heavy_token_num
        self.sorted_channels = model_runner.sorted_channels
        self.sparse_decode_threshold = (
            model_runner.server_args.ds_sparse_decode_threshold
        )

세 가지 Triton 커널을 사용한다. flash_decode_attention_fwd는 일반 dense decode, flash_decode_sparse_attention_fwd는 sparse decode, extend_attention_fwd는 prefill용이다. sorted_channels는 각 레이어의 채널 중요도 순서를 미리 계산한 텐서다.

핵심 파라미터

파라미터 설명
heavy_token_num Top-K로 선택할 중요 토큰 수 (ds_heavy_token_num)
sparse_decode_threshold Sparse decode를 활성화하는 최소 시퀀스 길이
sorted_channels 레이어별 채널 중요도 순서 (사전 계산)
BLOCK_SEQ Sparse attention의 블록 크기 (128)

init_forward_metadata: 메타데이터 초기화

Decode 모드에서는 approximate attention과 sparse attention을 위한 버퍼를 할당한다.

def init_forward_metadata(self, forward_batch: ForwardBatch):
    if forward_batch.forward_mode.is_decode():
        start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
        start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)

        total_num_tokens = torch.sum(forward_batch.seq_lens).item()
        attn_logits = torch.empty(
            (self.num_head, total_num_tokens),
            dtype=self.reduce_dtype, device="cuda",
        )

        att_out_approx = torch.empty(
            [self.num_head, bsz, max_seq_len],
            dtype=self.reduce_dtype, device="cuda",
        )

        block_seq_num = (self.heavy_token_num + self.BLOCK_SEQ - 1) // self.BLOCK_SEQ
        mid_out = torch.empty(
            [bsz, self.num_head, block_seq_num, self.head_dim],
            dtype=torch.float32, device="cuda",
        )
        mid_o_logexpsum = torch.empty(
            [bsz, self.num_head, block_seq_num],
            dtype=torch.float32, device="cuda",
        )

att_out_approx는 각 head에서 각 KV 토큰의 approximate score를 저장한다. mid_outmid_o_logexpsum은 block 단위 sparse attention 결과를 임시 저장하며, 최종적으로 online softmax로 합산된다.

forward_extend: Prefill 경로

Extend에서는 K의 label을 계산하여 KV 캐시와 함께 저장한다.

def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
    o = torch.empty_like(q)

    k_label = torch.gather(
        k, 2,
        self.sorted_channels[layer.layer_id]
            .unsqueeze(0).expand(k.shape[0], -1, -1),
    )

    if save_kv_cache:
        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer, forward_batch.out_cache_loc, k, v, k_label
        )

sorted_channels[layer_id]는 해당 레이어에서 가장 중요한 채널의 인덱스를 담고 있다. torch.gather로 K에서 이 채널들의 값만 추출하여 k_label을 만든다. 이 k_label은 KV 캐시에 K, V와 함께 저장되어 나중에 approximate attention에서 사용된다.

Extend 자체는 dense attention으로 수행한다. 스파스 최적화는 Decode에서만 적용된다.

forward_decode: 이중 스파스 어텐션

Decode의 핵심은 시퀀스 길이에 따른 분기다.

def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
    k_label = torch.gather(
        k, 2,
        self.sorted_channels[layer.layer_id]
            .unsqueeze(0).expand(k.shape[0], -1, -1),
    )

    if save_kv_cache:
        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer, forward_batch.out_cache_loc, k, v, k_label
        )

먼저 현재 토큰의 k_label을 계산하고 캐시에 저장한다.

분기: Dense vs Sparse

    if (min_seq_len < self.heavy_token_num
        or max_seq_len < self.sparse_decode_threshold):
        # Dense decode: 일반 flash decode
        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            ...
        )
    else:
        # Sparse decode: 이중 스파스 어텐션
        q_label = torch.gather(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            2,
            self.sorted_channels[layer.layer_id]
                .unsqueeze(0).expand(q.shape[0], -1, -1),
        )
        self.decode_sparse_attention_fwd(
            q.view(...),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(...),
            q_label,
            forward_batch.token_to_kv_pool.get_label_buffer(layer.layer_id),
            ds_req_to_token,
            forward_batch.seq_lens,
            max_seq_len,
            layer.scaling,
            layer.logit_cap,
            self.heavy_token_num,
            self.att_out_approx,
            self.mid_out,
            self.mid_o_logexpsum,
            self.BLOCK_SEQ,
        )

Dense 조건: 배치 내 최소 시퀀스 길이가 heavy_token_num보다 작거나, 최대 시퀀스 길이가 sparse_decode_threshold보다 작으면 sparse 최적화의 이점이 없으므로 dense decode를 사용한다.

Sparse 조건: Q에서도 sorted_channels 기준으로 label을 추출한다. decode_sparse_attention_fwd는 내부적으로 다음을 수행한다.

  1. q_labelk_label의 내적으로 각 KV 토큰의 approximate score를 계산
  2. Score 기준 top-K(heavy_token_num) 토큰 선택
  3. 선택된 토큰에 대해서만 full precision 어텐션 수행
  4. BLOCK_SEQ=128 단위로 블록화된 결과를 mid_out, mid_o_logexpsum에 저장
  5. Online softmax로 최종 결과 합산

sorted_channels: 채널 중요도

sorted_channels는 서버 시작 시 모델 가중치를 분석하여 사전 계산된다. 각 레이어의 K projection에서 채널별 분산이 큰 순서대로 정렬한 인덱스다. 분산이 큰 채널은 토큰 간 구별력이 높으므로, 적은 수의 채널만으로도 approximate attention의 정확도를 유지할 수 있다.

Dense vs Double Sparsity 비교

항목 Dense Attention Double Sparsity
Decode 연산량 O(h * N * d) O(h * N * d') + O(h * K * d)
d' (label dim) - d의 일부 (sorted channels)
K (heavy tokens) N (전체) heavy_token_num (예: 256)
추가 저장 없음 k_label per token
활성화 조건 항상 seq_len > threshold
정확도 정확 근사 (heavy token 선택에 의존)
메모리 절감 없음 attn_logits 크기: N → K

시퀀스 길이 N=4096, heavy_token_num=256, d'=d/4 기준으로, approximate attention은 ~4x, precise attention은 ~16x 연산 감소 효과가 있다.

설계 근거

Double Sparsity가 Triton 백엔드의 옵션으로 구현된 이유가 있다. Attention Registry에서 triton 백엔드 선택 시 enable_double_sparsity 플래그로 분기한다.

@register_attention_backend("triton")
def create_triton_backend(runner):
    if runner.server_args.enable_double_sparsity:
        return DoubleSparseAttnBackend(runner)
    else:
        return TritonAttnBackend(runner)

이는 Double Sparsity의 Triton 커널이 기존 Triton 백엔드의 확장이기 때문이다. KV 캐시 인터페이스를 공유하면서 k_label 저장만 추가한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글