본문으로 건너뛰기

[SGLang] Sparsity Algorithms: QUEST와 DeepSeek NSA 희소 패턴

들어가며

긴 컨텍스트 추론에서 KV 캐시의 크기는 시퀀스 길이에 비례하여 증가한다. SGLang의 Sparsity 모듈은 디코드 단계에서 중요한 KV 엔트리만 선별하여 어텐션을 수행함으로써 메모리와 계산을 절약한다. QUEST(Query-Aware Sparsity)와 DeepSeek NSA(Native Sparse Attention) 두 가지 알고리즘을 제공한다.

구조도

mem_cache/sparsity/
├── factory.py              ── 생성 함수, 레지스트리
├── core/
│   └── sparse_coordinator.py  ── SparseCoordinator + RequestTrackers
├── algorithms/
│   ├── base_algorithm.py      ── BaseSparseAlgorithm (인터페이스)
│   ├── quest_algorithm.py     ── QUEST 페이지 기반 희소성
│   └── deepseek_nsa.py        ── DeepSeek NSA 네이티브 인덱서
└── backend/
    └── backend_adaptor.py     ── FlashAttention / NSA 어댑터

[요청 생명주기]
Request Start ──► Prefill ──► Decode (반복) ──► Request End
     │              │           │                    │
 on_request    attention   attention_begin      on_request
   _begin        _end     + attention_end          _end

핵심 코드 분석

SparseCoordinator: 생명주기 관리

SparseCoordinator는 요청의 전체 생명주기를 관리한다. Prefill에서 표현(representation)을 구축하고, Decode에서 매 어텐션마다 중요한 KV를 검색한다.

class SparseCoordinator:
    def __init__(self, config, algorithm, backend_adaptor,
                 req_to_token_pool, token_to_kv_pool,
                 start_layer, end_layer, device):
        self.algorithm = algorithm
        self.backend_adaptor = backend_adaptor
        self.states = RequestTrackers(
            req_to_token_pool.req_to_token.shape[0],
            device, end_layer - start_layer + 1, ...)

핵심 API는 4개다.

  • on_request_begin/end: 요청 등록/해제
  • attention_begin: TopK 검색 + 메타데이터 적응
  • attention_end: 표현 구축/업데이트

QUEST 알고리즘: 바운딩 박스 기반 페이지 선택

QUEST는 각 KV 페이지에 대해 키 값의 min/max를 유지한다. 쿼리가 주어지면 실제 dot product 없이 상한(upper bound)을 계산하여 중요한 페이지를 선택한다.

class QuestAlgorithm(BaseSparseAlgorithmImpl):
    def __init__(self, config, device, **kwargs):
        self.page_k_min = {}   # layer_id -> [num_pages, heads, dim]
        self.page_k_max = {}
        self.page_valid = {}

표현 구축 시 각 페이지의 키 값 범위를 계산한다.

def _compute_page_representations(self, layer_id, reqs, seq_lens, ...):
    keys = k_buffer[phys_tok].to(torch.float32)
    mask = tok_mask.unsqueeze(-1).unsqueeze(-1)
    page_min = torch.where(mask, keys, torch.full_like(keys, float("inf"))).amin(dim=2)
    page_max = torch.where(mask, keys, torch.full_like(keys, float("-inf"))).amax(dim=2)

검색 시 criticality = where(q >= 0, q * k_max, q * k_min)으로 어텐션 점수 상한을 추정한다.

def _retrieve_page_scores(self, layer_id, phys_pages, req_pool_indices, queries):
    k_min = self.page_k_min[layer_id][phys_pages_clamped]
    k_max = self.page_k_max[layer_id][phys_pages_clamped]
    criticality = torch.where(q >= 0, q * k_max, q * k_min).sum(dim=(2, 3))
    criticality = torch.where(valid_mask, criticality, float("-inf"))
    return criticality

이 방식은 GQA/MQA에서도 동작한다. 쿼리 헤드 수가 KV 헤드 수보다 많으면 그룹 평균을 취한다.

DeepSeek NSA: 네이티브 인덱서 위임

DeepSeek NSA는 모델 자체에 내장된 sparse indexer를 사용한다. 따라서 별도의 표현 구축이 필요 없고, retrieve_topk에서 indexer를 직접 호출한다.

class DeepSeekNSAAlgorithm(BaseSparseAlgorithmImpl):
    def retrieve_topk(self, queries, layer_id, ..., **kwargs):
        indexer = kwargs.get("indexer")
        return (
            indexer(x=x, q_lora=q_lora, positions=positions,
                    forward_batch=forward_batch, layer_id=layer_id),
            None,
        )

    def initialize_representation_pool(self, ...):
        pass  # NSA는 자체 표현 사용

    def construct_representations(self, ...):
        pass  # 구축 불필요

팩토리와 설정 파싱

factory.py는 JSON 문자열에서 sparse 설정을 파싱하고, 알고리즘과 백엔드 어댑터를 생성한다.

_ALGORITHM_REGISTRY = {
    "quest": lambda config, device, **kw: QuestAlgorithm(config, device, **kw),
    "deepseek_nsa": lambda config, device, **kw: DeepSeekNSAAlgorithm(config, device, **kw),
}

def create_sparse_coordinator(device, req_to_token_pool, ...):
    config = _parse_sparse_config(server_args)
    algorithm = _create_sparse_algorithm(config, device)
    backend_adaptor = _create_backend_adaptor(config.backend, device, algorithm, ...)
    return SparseCoordinator(config=config, algorithm=algorithm, ...)

SparseConfig의 주요 필드는 top_k (선택할 페이지 수), device_buffer_size, page_size, min_sparse_prompt_len (최소 프롬프트 길이) 등이다.

QUEST vs NSA 비교

구분 QUEST DeepSeek NSA
표현 구축 런타임에 page min/max 계산 모델 내장 인덱서
검색 방식 바운딩 박스 상한 추정 학습된 인덱서
범용성 모든 Transformer 모델 DeepSeek 전용
추가 메모리 페이지당 min/max 저장 인덱서 가중치
백엔드 FlashAttention NSA 전용

관련 포스트

  • Model Configuration 시스템: 모델 설정 관리
  • Batch Overlap: 연산-통신 오버랩 최적화

참고

댓글

관련 포스트

SGLang 의 다른글