본문으로 건너뛰기

[SGLang] N-gram Draft: 모델 프리 투기적 디코딩

들어가며

N-gram Draft는 별도의 드래프트 모델 없이 토큰 시퀀스의 N-gram 통계만으로 다음 토큰을 추측하는 speculative decoding 방식이다. GPU 메모리를 추가로 소비하지 않으면서도 반복적인 패턴이 많은 워크로드(코드 생성, 정형 텍스트 등)에서 유효한 가속을 제공한다. SGLang의 NGRAMWorker는 C++ 기반 NgramCorpus와 트리 마스크 재구성 커널을 결합하여 이를 구현한다.

구조도

┌──────────────────────────────────────────────────┐
│                  NGRAMWorker                       │
│                                                    │
│  ┌──────────────────────────────────────────────┐ │
│  │              NgramCorpus (C++)                │ │
│  │  ┌─────────────┐   ┌───────────────────────┐│ │
│  │  │ Trie 구조    │   │ External Corpus (SAM) ││ │
│  │  │ batch_put()  │   │ load / remove         ││ │
│  │  │ batch_get()  │   └───────────────────────┘│ │
│  │  └─────────────┘                             │ │
│  └──────────────────────────────────────────────┘ │
│         │                                         │
│         ▼ draft_tokens + tree_mask                │
│  ┌──────────────────────────────────────────────┐ │
│  │  reconstruct_indices_from_tree_mask           │ │
│  │  (sgl_kernel CUDA kernel)                     │ │
│  │  → positions, retrive_index,                  │ │
│  │    retrive_next_token, retrive_next_sibling   │ │
│  └──────────────────────────────────────────────┘ │
│         │                                         │
│         ▼                                         │
│  ┌──────────────────────────────────────────────┐ │
│  │           NgramVerifyInput                    │ │
│  │  → Target Model forward (TARGET_VERIFY)       │ │
│  │  → verify (greedy / sampling)                 │ │
│  └──────────────────────────────────────────────┘ │
└──────────────────────────────────────────────────┘

핵심 코드 분석

1. NGRAMWorker 초기화

python/sglang/srt/speculative/ngram_worker.py에서 NGRAMWorker는 타겟 워커를 직접 참조하며 별도 드래프트 모델을 로드하지 않는다.

class NGRAMWorker:
    def __init__(self, server_args, ..., target_worker):
        self.target_worker = target_worker
        self.model_runner = target_worker.model_runner
        self.draft_token_num = server_args.speculative_num_draft_tokens
        self.max_trie_depth = server_args.speculative_ngram_max_trie_depth

        self.ngram_corpus = NgramCorpus(
            min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth,
            max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
            match_type=server_args.speculative_ngram_match_type,
            capacity=server_args.speculative_ngram_capacity,
            max_trie_depth=server_args.speculative_ngram_max_trie_depth,
            draft_token_num=server_args.speculative_num_draft_tokens,
            external_sam_budget=server_args.speculative_ngram_external_sam_budget,
            external_corpus_max_tokens=server_args.speculative_ngram_external_corpus_max_tokens,
        )

NgramCorpus는 C++로 구현된 Trie 기반 코퍼스다. min_bfs_breadthmax_bfs_breadth가 트리 탐색 폭을 제어하여 드래프트 토큰의 다양성을 조절한다.

2. 외부 코퍼스 로딩

if server_args.speculative_ngram_external_corpus_path is not None:
    from sglang.srt.speculative.cpp_ngram.external_corpus import (
        iter_external_corpus_chunks,
    )
    corpus_path = server_args.speculative_ngram_external_corpus_path
    chunks = list(iter_external_corpus_chunks(
        corpus_path, target_worker.tokenizer,
        server_args.speculative_ngram_external_corpus_max_tokens,
    ))
    loaded = self.add_external_corpus(corpus_path, chunks)
    self.commit_corpus_load(corpus_path, loaded)

외부 텍스트 파일을 토큰화하여 Trie에 적재할 수 있다. 도메인 특화 코퍼스를 미리 로드하면 해당 도메인 텍스트 생성 시 acceptance rate가 크게 향상된다.

3. 드래프트 토큰 준비

def _prepare_draft_tokens(self, batch):
    self.ngram_corpus.synchronize()
    req_ids, batch_tokens, total_lens = [], [], []
    for req in batch.reqs:
        check_token = self._efficient_concat_last_n(
            req.origin_input_ids, req.output_ids, self.max_trie_depth
        )
        req_ids.append(req.rid)
        batch_tokens.append(check_token)
        total_lens.append(len(req.origin_input_ids) + len(req.output_ids))
    req_drafts, mask = self.ngram_corpus.batch_get(req_ids, batch_tokens, total_lens)
    return req_drafts, mask

_efficient_concat_last_n은 입력과 출력의 마지막 N 토큰만 효율적으로 추출한다. batch_get은 Trie에서 매칭되는 연속 토큰을 트리 구조로 반환하며, mask는 트리 마스크(어텐션 패턴)를 나타낸다.

4. 트리 인덱스 재구성

드래프트 토큰과 마스크에서 검증에 필요한 인덱스 구조를 CUDA 커널로 생성한다.

def _prepare_for_speculative_decoding(self, batch):
    req_drafts, mask = self._prepare_draft_tokens(batch)
    tree_mask.copy_(torch.from_numpy(mask), non_blocking=True)
    draft_tokens.copy_(torch.from_numpy(req_drafts), non_blocking=True)

    reconstruct_indices_from_tree_mask(
        tree_mask, batch.seq_lens,
        positions,            # mutable
        retrive_index,        # mutable
        retrive_next_token,   # mutable
        retrive_next_sibling, # mutable
        bs, self.draft_token_num,
    )

reconstruct_indices_from_tree_masksgl_kernel의 CUDA 커널로, 트리 마스크에서 DFS 순회용 인덱스(retrive_next_token, retrive_next_sibling)를 병렬로 계산한다.

5. Full Mask 구성

if USE_FULL_MASK:
    tree_mask = []
    mask = mask.reshape(batch.batch_size(), self.draft_token_num, self.draft_token_num)
    for i, req in enumerate(batch.reqs):
        seq_len = len(req.origin_input_ids) + len(req.output_ids)
        req_mask = torch.ones((self.draft_token_num, seq_len - 1)).cuda()
        req_mask = torch.cat(
            (req_mask, torch.from_numpy(mask[i]).cuda()), dim=1
        ).to(torch.bool)
        tree_mask.append(req_mask.flatten())
    tree_mask = torch.cat(tree_mask, dim=0)

Full mask 모드에서는 prefix 부분(모두 attend)과 draft 부분(트리 마스크)을 결합하여 완전한 attention mask를 구성한다.

6. 코퍼스 업데이트

검증 후 수락된 토큰을 코퍼스에 반영한다.

def _update_ngram_corpus(self, batch):
    batch_tokens = []
    for req in batch.reqs:
        put_ids = self._efficient_concat_last_n(
            req.origin_input_ids, req.output_ids, self.max_trie_depth
        )
        batch_tokens.append(put_ids)
    self.ngram_corpus.batch_put(batch_tokens)

생성된 텍스트가 즉시 코퍼스에 추가되어, 동일 세션 내에서 반복 패턴의 매칭 확률이 점진적으로 높아진다.

7. 완료 요청 정리

finished_req_ids = []
for req in batch.reqs:
    if req.finished() or req.is_retracted:
        finished_req_ids.append(req.rid)
if finished_req_ids:
    self.ngram_corpus.erase_match_state(finished_req_ids)

완료된 요청의 매칭 상태를 정리하여 메모리 누수를 방지한다.

EAGLE vs N-gram 비교

항목 EAGLE N-gram
드래프트 모델 별도 모델 필요 불필요
GPU 메모리 드래프트 모델 메모리 소비 거의 없음
Acceptance Rate 높음 (70-90%) 중간 (워크로드 의존)
적합한 워크로드 범용 반복 패턴 많은 텍스트
Overlap V2 지원 아니오

설계 근거

모델 프리 접근: 드래프트 모델이 필요 없으므로 배포가 단순하고, 기존 서빙 인프라에 최소 변경으로 적용 가능하다.

C++ Trie 구현: 고빈도 매칭/삽입이 필요한 N-gram 연산을 C++로 구현하여 Python 오버헤드를 제거했다.

텐서 사전 할당: _init_preallocated_tensors에서 최대 배치 크기 기준으로 텐서를 미리 할당하여 런타임 메모리 할당을 방지한다.

def _init_preallocated_tensors(self):
    max_total_drafts = self.max_batch_size * self.draft_token_num
    self.draft_tokens = torch.empty((max_total_drafts,), dtype=torch.int64, device=self.device)
    self.positions = torch.empty((max_total_drafts,), dtype=torch.int64, device=self.device)
    # batch size별 슬라이스를 미리 생성
    for bs in range(0, self.max_batch_size + 1):
        self.draft_tokens_batch.append(self.draft_tokens[: bs * self.draft_token_num])

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글