[SGLang] N-gram Draft: 모델 프리 투기적 디코딩
들어가며
N-gram Draft는 별도의 드래프트 모델 없이 토큰 시퀀스의 N-gram 통계만으로 다음 토큰을 추측하는 speculative decoding 방식이다. GPU 메모리를 추가로 소비하지 않으면서도 반복적인 패턴이 많은 워크로드(코드 생성, 정형 텍스트 등)에서 유효한 가속을 제공한다. SGLang의 NGRAMWorker는 C++ 기반 NgramCorpus와 트리 마스크 재구성 커널을 결합하여 이를 구현한다.
구조도
┌──────────────────────────────────────────────────┐
│ NGRAMWorker │
│ │
│ ┌──────────────────────────────────────────────┐ │
│ │ NgramCorpus (C++) │ │
│ │ ┌─────────────┐ ┌───────────────────────┐│ │
│ │ │ Trie 구조 │ │ External Corpus (SAM) ││ │
│ │ │ batch_put() │ │ load / remove ││ │
│ │ │ batch_get() │ └───────────────────────┘│ │
│ │ └─────────────┘ │ │
│ └──────────────────────────────────────────────┘ │
│ │ │
│ ▼ draft_tokens + tree_mask │
│ ┌──────────────────────────────────────────────┐ │
│ │ reconstruct_indices_from_tree_mask │ │
│ │ (sgl_kernel CUDA kernel) │ │
│ │ → positions, retrive_index, │ │
│ │ retrive_next_token, retrive_next_sibling │ │
│ └──────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────┐ │
│ │ NgramVerifyInput │ │
│ │ → Target Model forward (TARGET_VERIFY) │ │
│ │ → verify (greedy / sampling) │ │
│ └──────────────────────────────────────────────┘ │
└──────────────────────────────────────────────────┘
핵심 코드 분석
1. NGRAMWorker 초기화
python/sglang/srt/speculative/ngram_worker.py에서 NGRAMWorker는 타겟 워커를 직접 참조하며 별도 드래프트 모델을 로드하지 않는다.
class NGRAMWorker:
def __init__(self, server_args, ..., target_worker):
self.target_worker = target_worker
self.model_runner = target_worker.model_runner
self.draft_token_num = server_args.speculative_num_draft_tokens
self.max_trie_depth = server_args.speculative_ngram_max_trie_depth
self.ngram_corpus = NgramCorpus(
min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth,
max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
match_type=server_args.speculative_ngram_match_type,
capacity=server_args.speculative_ngram_capacity,
max_trie_depth=server_args.speculative_ngram_max_trie_depth,
draft_token_num=server_args.speculative_num_draft_tokens,
external_sam_budget=server_args.speculative_ngram_external_sam_budget,
external_corpus_max_tokens=server_args.speculative_ngram_external_corpus_max_tokens,
)
NgramCorpus는 C++로 구현된 Trie 기반 코퍼스다. min_bfs_breadth와 max_bfs_breadth가 트리 탐색 폭을 제어하여 드래프트 토큰의 다양성을 조절한다.
2. 외부 코퍼스 로딩
if server_args.speculative_ngram_external_corpus_path is not None:
from sglang.srt.speculative.cpp_ngram.external_corpus import (
iter_external_corpus_chunks,
)
corpus_path = server_args.speculative_ngram_external_corpus_path
chunks = list(iter_external_corpus_chunks(
corpus_path, target_worker.tokenizer,
server_args.speculative_ngram_external_corpus_max_tokens,
))
loaded = self.add_external_corpus(corpus_path, chunks)
self.commit_corpus_load(corpus_path, loaded)
외부 텍스트 파일을 토큰화하여 Trie에 적재할 수 있다. 도메인 특화 코퍼스를 미리 로드하면 해당 도메인 텍스트 생성 시 acceptance rate가 크게 향상된다.
3. 드래프트 토큰 준비
def _prepare_draft_tokens(self, batch):
self.ngram_corpus.synchronize()
req_ids, batch_tokens, total_lens = [], [], []
for req in batch.reqs:
check_token = self._efficient_concat_last_n(
req.origin_input_ids, req.output_ids, self.max_trie_depth
)
req_ids.append(req.rid)
batch_tokens.append(check_token)
total_lens.append(len(req.origin_input_ids) + len(req.output_ids))
req_drafts, mask = self.ngram_corpus.batch_get(req_ids, batch_tokens, total_lens)
return req_drafts, mask
_efficient_concat_last_n은 입력과 출력의 마지막 N 토큰만 효율적으로 추출한다. batch_get은 Trie에서 매칭되는 연속 토큰을 트리 구조로 반환하며, mask는 트리 마스크(어텐션 패턴)를 나타낸다.
4. 트리 인덱스 재구성
드래프트 토큰과 마스크에서 검증에 필요한 인덱스 구조를 CUDA 커널로 생성한다.
def _prepare_for_speculative_decoding(self, batch):
req_drafts, mask = self._prepare_draft_tokens(batch)
tree_mask.copy_(torch.from_numpy(mask), non_blocking=True)
draft_tokens.copy_(torch.from_numpy(req_drafts), non_blocking=True)
reconstruct_indices_from_tree_mask(
tree_mask, batch.seq_lens,
positions, # mutable
retrive_index, # mutable
retrive_next_token, # mutable
retrive_next_sibling, # mutable
bs, self.draft_token_num,
)
reconstruct_indices_from_tree_mask는 sgl_kernel의 CUDA 커널로, 트리 마스크에서 DFS 순회용 인덱스(retrive_next_token, retrive_next_sibling)를 병렬로 계산한다.
5. Full Mask 구성
if USE_FULL_MASK:
tree_mask = []
mask = mask.reshape(batch.batch_size(), self.draft_token_num, self.draft_token_num)
for i, req in enumerate(batch.reqs):
seq_len = len(req.origin_input_ids) + len(req.output_ids)
req_mask = torch.ones((self.draft_token_num, seq_len - 1)).cuda()
req_mask = torch.cat(
(req_mask, torch.from_numpy(mask[i]).cuda()), dim=1
).to(torch.bool)
tree_mask.append(req_mask.flatten())
tree_mask = torch.cat(tree_mask, dim=0)
Full mask 모드에서는 prefix 부분(모두 attend)과 draft 부분(트리 마스크)을 결합하여 완전한 attention mask를 구성한다.
6. 코퍼스 업데이트
검증 후 수락된 토큰을 코퍼스에 반영한다.
def _update_ngram_corpus(self, batch):
batch_tokens = []
for req in batch.reqs:
put_ids = self._efficient_concat_last_n(
req.origin_input_ids, req.output_ids, self.max_trie_depth
)
batch_tokens.append(put_ids)
self.ngram_corpus.batch_put(batch_tokens)
생성된 텍스트가 즉시 코퍼스에 추가되어, 동일 세션 내에서 반복 패턴의 매칭 확률이 점진적으로 높아진다.
7. 완료 요청 정리
finished_req_ids = []
for req in batch.reqs:
if req.finished() or req.is_retracted:
finished_req_ids.append(req.rid)
if finished_req_ids:
self.ngram_corpus.erase_match_state(finished_req_ids)
완료된 요청의 매칭 상태를 정리하여 메모리 누수를 방지한다.
EAGLE vs N-gram 비교
| 항목 | EAGLE | N-gram |
|---|---|---|
| 드래프트 모델 | 별도 모델 필요 | 불필요 |
| GPU 메모리 | 드래프트 모델 메모리 소비 | 거의 없음 |
| Acceptance Rate | 높음 (70-90%) | 중간 (워크로드 의존) |
| 적합한 워크로드 | 범용 | 반복 패턴 많은 텍스트 |
| Overlap V2 지원 | 예 | 아니오 |
설계 근거
모델 프리 접근: 드래프트 모델이 필요 없으므로 배포가 단순하고, 기존 서빙 인프라에 최소 변경으로 적용 가능하다.
C++ Trie 구현: 고빈도 매칭/삽입이 필요한 N-gram 연산을 C++로 구현하여 Python 오버헤드를 제거했다.
텐서 사전 할당: _init_preallocated_tensors에서 최대 배치 크기 기준으로 텐서를 미리 할당하여 런타임 메모리 할당을 방지한다.
def _init_preallocated_tensors(self):
max_total_drafts = self.max_batch_size * self.draft_token_num
self.draft_tokens = torch.empty((max_total_drafts,), dtype=torch.int64, device=self.device)
self.positions = torch.empty((max_total_drafts,), dtype=torch.int64, device=self.device)
# batch size별 슬라이스를 미리 생성
for bs in range(0, self.max_batch_size + 1):
self.draft_tokens_batch.append(self.draft_tokens[: bs * self.draft_token_num])
관련 포스트
참고
- SGLang N-gram Worker 소스
- SGLang N-gram Info 소스
- Yang et al., "Inference with Reference: Lossless Acceleration of Large Language Models" (2023)
관련 포스트
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
- [논문리뷰] NGM: A Plug-and-Play Training-Free Memory Module for LLMs
- [sglang] sglang diffusion 모델 성능 향상: Cache-DiT와 torch.compile의 최적화된 적용 순서
SGLang 의 다른글
- 이전글 [SGLang] Multi-Layer EAGLE: 다계층 드래프트로 더 긴 추측
- 현재글 : [SGLang] N-gram Draft: 모델 프리 투기적 디코딩
- 다음글 [SGLang] DFlash: Flash 기반 고속 드래프팅
댓글