[SGLang] Sparsity Algorithms: QUEST와 DeepSeek NSA 희소 패턴
들어가며
긴 컨텍스트 추론에서 KV 캐시의 크기는 시퀀스 길이에 비례하여 증가한다. SGLang의 Sparsity 모듈은 디코드 단계에서 중요한 KV 엔트리만 선별하여 어텐션을 수행함으로써 메모리와 계산을 절약한다. QUEST(Query-Aware Sparsity)와 DeepSeek NSA(Native Sparse Attention) 두 가지 알고리즘을 제공한다.
구조도
mem_cache/sparsity/
├── factory.py ── 생성 함수, 레지스트리
├── core/
│ └── sparse_coordinator.py ── SparseCoordinator + RequestTrackers
├── algorithms/
│ ├── base_algorithm.py ── BaseSparseAlgorithm (인터페이스)
│ ├── quest_algorithm.py ── QUEST 페이지 기반 희소성
│ └── deepseek_nsa.py ── DeepSeek NSA 네이티브 인덱서
└── backend/
└── backend_adaptor.py ── FlashAttention / NSA 어댑터
[요청 생명주기]
Request Start ──► Prefill ──► Decode (반복) ──► Request End
│ │ │ │
on_request attention attention_begin on_request
_begin _end + attention_end _end
핵심 코드 분석
SparseCoordinator: 생명주기 관리
SparseCoordinator는 요청의 전체 생명주기를 관리한다. Prefill에서 표현(representation)을 구축하고, Decode에서 매 어텐션마다 중요한 KV를 검색한다.
class SparseCoordinator:
def __init__(self, config, algorithm, backend_adaptor,
req_to_token_pool, token_to_kv_pool,
start_layer, end_layer, device):
self.algorithm = algorithm
self.backend_adaptor = backend_adaptor
self.states = RequestTrackers(
req_to_token_pool.req_to_token.shape[0],
device, end_layer - start_layer + 1, ...)
핵심 API는 4개다.
on_request_begin/end: 요청 등록/해제attention_begin: TopK 검색 + 메타데이터 적응attention_end: 표현 구축/업데이트
QUEST 알고리즘: 바운딩 박스 기반 페이지 선택
QUEST는 각 KV 페이지에 대해 키 값의 min/max를 유지한다. 쿼리가 주어지면 실제 dot product 없이 상한(upper bound)을 계산하여 중요한 페이지를 선택한다.
class QuestAlgorithm(BaseSparseAlgorithmImpl):
def __init__(self, config, device, **kwargs):
self.page_k_min = {} # layer_id -> [num_pages, heads, dim]
self.page_k_max = {}
self.page_valid = {}
표현 구축 시 각 페이지의 키 값 범위를 계산한다.
def _compute_page_representations(self, layer_id, reqs, seq_lens, ...):
keys = k_buffer[phys_tok].to(torch.float32)
mask = tok_mask.unsqueeze(-1).unsqueeze(-1)
page_min = torch.where(mask, keys, torch.full_like(keys, float("inf"))).amin(dim=2)
page_max = torch.where(mask, keys, torch.full_like(keys, float("-inf"))).amax(dim=2)
검색 시 criticality = where(q >= 0, q * k_max, q * k_min)으로 어텐션 점수 상한을 추정한다.
def _retrieve_page_scores(self, layer_id, phys_pages, req_pool_indices, queries):
k_min = self.page_k_min[layer_id][phys_pages_clamped]
k_max = self.page_k_max[layer_id][phys_pages_clamped]
criticality = torch.where(q >= 0, q * k_max, q * k_min).sum(dim=(2, 3))
criticality = torch.where(valid_mask, criticality, float("-inf"))
return criticality
이 방식은 GQA/MQA에서도 동작한다. 쿼리 헤드 수가 KV 헤드 수보다 많으면 그룹 평균을 취한다.
DeepSeek NSA: 네이티브 인덱서 위임
DeepSeek NSA는 모델 자체에 내장된 sparse indexer를 사용한다. 따라서 별도의 표현 구축이 필요 없고, retrieve_topk에서 indexer를 직접 호출한다.
class DeepSeekNSAAlgorithm(BaseSparseAlgorithmImpl):
def retrieve_topk(self, queries, layer_id, ..., **kwargs):
indexer = kwargs.get("indexer")
return (
indexer(x=x, q_lora=q_lora, positions=positions,
forward_batch=forward_batch, layer_id=layer_id),
None,
)
def initialize_representation_pool(self, ...):
pass # NSA는 자체 표현 사용
def construct_representations(self, ...):
pass # 구축 불필요
팩토리와 설정 파싱
factory.py는 JSON 문자열에서 sparse 설정을 파싱하고, 알고리즘과 백엔드 어댑터를 생성한다.
_ALGORITHM_REGISTRY = {
"quest": lambda config, device, **kw: QuestAlgorithm(config, device, **kw),
"deepseek_nsa": lambda config, device, **kw: DeepSeekNSAAlgorithm(config, device, **kw),
}
def create_sparse_coordinator(device, req_to_token_pool, ...):
config = _parse_sparse_config(server_args)
algorithm = _create_sparse_algorithm(config, device)
backend_adaptor = _create_backend_adaptor(config.backend, device, algorithm, ...)
return SparseCoordinator(config=config, algorithm=algorithm, ...)
SparseConfig의 주요 필드는 top_k (선택할 페이지 수), device_buffer_size, page_size, min_sparse_prompt_len (최소 프롬프트 길이) 등이다.
QUEST vs NSA 비교
| 구분 | QUEST | DeepSeek NSA |
|---|---|---|
| 표현 구축 | 런타임에 page min/max 계산 | 모델 내장 인덱서 |
| 검색 방식 | 바운딩 박스 상한 추정 | 학습된 인덱서 |
| 범용성 | 모든 Transformer 모델 | DeepSeek 전용 |
| 추가 메모리 | 페이지당 min/max 저장 | 인덱서 가중치 |
| 백엔드 | FlashAttention | NSA 전용 |
관련 포스트
- Model Configuration 시스템: 모델 설정 관리
- Batch Overlap: 연산-통신 오버랩 최적화
참고
- 소스 코드:
python/sglang/srt/mem_cache/sparsity/ - QUEST: Tang et al., "QUEST: Query-Aware Sparsity for Efficient Long-Context LLM Inference" (2024)
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Deep GEMM Wrapper: 최적화 행렬 곱 라이브러리
- 현재글 : [SGLang] Sparsity Algorithms: QUEST와 DeepSeek NSA 희소 패턴
- 다음글 [SGLang] Batch Overlap: 연산-통신 오버랩 최적화
댓글