[SGLang] Tree Search & Verification: 트리 기반 추측과 검증
들어가며
Speculative decoding에서 드래프트 토큰을 트리 구조로 배치하면 하나의 타겟 모델 forward로 여러 후보 경로를 동시에 검증할 수 있다. 선형 체인(topk=1)은 하나의 경로만 검증하지만, 트리(topk>1)는 분기된 경로를 병렬로 평가하여 최소 하나 이상의 긴 수락 경로를 확보할 확률을 높인다. SGLang의 spec_utils.py와 eagle_utils.py에 이 트리 구축/검증 로직이 구현되어 있다.
구조도
[root: verified_id]
/ | \
[tok_A] [tok_B] [tok_C] ← step 0 (topk=3)
/ \ / \ / \
[tok_D] [tok_E] [tok_F] [tok_G] [tok_H] [tok_I] ← step 1
| | | | | |
[...] [...] [...] [...] [...] [...] ← step 2
retrive_next_token: 각 노드의 첫 번째 자식 인덱스
retrive_next_sibling: 각 노드의 다음 형제 인덱스
retrive_index: 검증 결과에서 수락 경로 추적용 인덱스
핵심 코드 분석
1. 트리 구축: organize_draft_results
python/sglang/srt/speculative/eagle_utils.py에서 드래프트 결과를 정리한다.
def organize_draft_results(score_list, token_list, parents_list, num_draft_token):
score_list = torch.cat(score_list, dim=1).flatten(1)
ss_token_list = torch.cat(token_list, dim=1)
top_scores = torch.topk(score_list, num_draft_token - 1, dim=-1)
top_scores_index = top_scores.indices
top_scores_index = torch.sort(top_scores_index).values
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
if len(parents_list) > 1:
parent_list = torch.cat(parents_list[:-1], dim=1)
else:
parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)
return parent_list, top_scores_index, draft_tokens
모든 step의 score를 합쳐 상위 num_draft_token - 1개를 선택한다. -1인 이유는 root(verified_id)가 이미 하나를 차지하기 때문이다. top_scores_index를 정렬하여 트리 구조를 위치 기반으로 일관되게 유지한다.
2. 트리 마스크 생성: build_tree_kernel_efficient
def build_tree_kernel_efficient(
verified_id, parent_list, top_scores_index, draft_tokens,
seq_lens, seq_lens_sum, topk, spec_steps, num_verify_tokens,
tree_mask_mode=TreeMaskMode.FULL_MASK, ...
):
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
if tree_mask_mode == TreeMaskMode.FULL_MASK:
tree_mask = torch.full(
(seq_lens_sum * num_verify_tokens + num_verify_tokens * num_verify_tokens * bs,),
True, dtype=torch.bool, device=device,
)
elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
tree_mask = torch.zeros(
(num_verify_tokens * bs,), dtype=packed_dtypes[packed_dtype_idx], device=device,
)
트리 마스크는 3가지 모드를 지원한다:
FULL_MASK: prefix + draft 영역 전체를 포함하는 완전 마스크QLEN_ONLY: draft 영역만 포함하는 축소 마스크 (FlashInfer 호환)QLEN_ONLY_BITPACKING: bitwise packing으로 메모리 절약
sgl_build_tree_kernel_efficient CUDA 커널이 parent_list에서 트리 마스크, position, retrieve 인덱스를 한 번에 계산한다.
3. DFS 기반 Grammar 트리 순회
구조화된 출력(grammar)을 지원하기 위해 DFS로 트리를 순회한다.
def traverse_tree(retrieve_next_token, retrieve_next_sibling, draft_tokens, grammar, allocate_token_bitmask, vocab_size):
def dfs(curr, retrieve_next_token, retrieve_next_sibling, parent_pos):
if curr == 0:
accepted = True
else:
parent_bitmask = allocate_token_bitmask[parent_pos]
curr_token_id = draft_tokens[curr]
accepted = (parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))) != 0
if accepted:
if curr != 0:
grammar.accept_token(draft_tokens[curr])
if not grammar.is_terminated():
grammar.fill_vocab_mask(allocate_token_bitmask, curr)
if retrieve_next_token[curr] != -1:
dfs(retrieve_next_token[curr], ...)
if curr != 0:
grammar.rollback(1)
if retrieve_next_sibling[curr] != -1:
dfs(retrieve_next_sibling[curr], ..., parent_pos)
dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
retrieve_next_token이 자식, retrieve_next_sibling이 형제를 가리키는 이진 트리 표현을 사용한다. Grammar에서 수락된 토큰은 accept_token으로 상태를 전진시키고, 거부 시 rollback으로 되돌린다. 각 위치에서 fill_vocab_mask로 다음 허용 토큰의 bitmask를 생성한다.
4. Greedy 트리 검증
eagle_utils.py의 verify_tree_greedy_func는 CUDA 커널 기반 greedy 검증을 수행한다.
predict, accept_index, accept_length = verify_tree_greedy_func(
predicts=predict, # 출력: 각 위치의 예측 토큰
accept_index=accept_index, # 출력: 수락된 인덱스
accept_token_num=accept_length,
candidates=candidates, # 드래프트 토큰
retrive_index=self.retrive_index,
retrive_next_token=self.retrive_next_token,
retrive_next_sibling=self.retrive_next_sibling,
target_predict=target_predict, # 타겟 모델의 argmax
topk=self.topk,
)
retrive_next_token과 retrive_next_sibling을 따라 트리를 순회하며, 타겟 예측과 드래프트 토큰이 일치하면 수락하고 자식으로 내려간다. 불일치 시 형제로 이동한다.
5. Sampling 기반 트리 검증
non-greedy 샘플링 시에는 rejection sampling을 사용한다.
tree_speculative_sampling_target_only(
predicts=predict,
accept_index=accept_index,
accept_token_num=accept_length,
candidates=candidates,
retrive_index=self.retrive_index,
retrive_next_token=self.retrive_next_token,
retrive_next_sibling=self.retrive_next_sibling,
uniform_samples=coins,
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=server_args.speculative_accept_threshold_single,
threshold_acc=server_args.speculative_accept_threshold_acc,
deterministic=True,
)
threshold_single과 threshold_acc가 단일 토큰 및 누적 수락 임계값을 제어한다. uniform_samples는 rejection sampling용 난수, coins_for_final_sampling은 마지막 bonus 토큰 샘플링용이다.
6. 시뮬레이션 모드
벤치마킹을 위한 acceptance length 시뮬레이션도 지원한다.
def generate_simulated_accept_index(accept_index, predict, accept_length, bs, spec_steps,
simulate_acc_len, simulate_acc_method):
if simulate_acc_method == "match-expected":
lower = int(simulate_acc_len // 1)
upper = lower + 1 if lower < spec_steps + 1 else lower
weight_upper = simulate_acc_len - lower
probs = torch.tensor([1 - weight_upper, weight_upper])
sampled_index = torch.multinomial(probs, num_samples=1)
simulate_acc_len = lower if sampled_index == 0 else upper
SGLANG_SIMULATE_ACC_LEN 환경변수로 원하는 평균 acceptance length를 설정하여, 실제 모델 없이도 speculative decoding의 스케줄링 성능을 테스트할 수 있다.
7. TP 동기화
멀티 GPU에서 검증 결과가 일관되도록 broadcast한다.
tp_group = get_attention_tp_group() if is_dp_attention_enabled() else get_tp_group()
if tp_group.world_size > 1:
tp_group.broadcast(predict, src=0)
tp_group.broadcast(accept_index, src=0)
tp_group.broadcast(accept_length, src=0)
softmax의 floating-point 비결정성으로 인해 GPU마다 다른 토큰을 샘플링할 수 있으므로, rank 0의 결과를 broadcast하여 일관성을 보장한다.
TreeMaskMode 비교
| 모드 | 크기 | 용도 |
|---|---|---|
| FULL_MASK | seq_lens_sum * N + N^2 * bs |
범용, 모든 backend 지원 |
| QLEN_ONLY | N^2 * bs |
FlashInfer 최적화 |
| QLEN_ONLY_BITPACKING | N * bs (packed) |
메모리 최적화 |
N = num_verify_tokens, bs = batch_size
설계 근거
이진 트리 인코딩: next_token(첫 자식)과 next_sibling(다음 형제)으로 N-ary 트리를 2개의 포인터로 표현한다. 이는 CUDA 커널에서 효율적으로 순회 가능한 구조다.
Score 기반 가지치기: 전체 후보 중 num_draft_token - 1개만 선택하여 검증 비용을 제한한다. 점수가 낮은 가지는 자동으로 제거된다.
Grammar 통합: 구조화된 출력에서도 speculative decoding을 사용할 수 있도록, 트리 순회 중 grammar 상태를 추적하고 bitmask를 생성한다.
관련 포스트
- Speculative Decoding 개요
- EAGLE: 은닉 상태 기반 드래프트 모델
- N-gram Draft: 모델 프리 투기적 디코딩
- DFlash: Flash 기반 고속 드래프팅
참고
- SGLang spec_utils.py 소스
- SGLang eagle_utils.py 소스
- Miao et al., "SpecInfer: Accelerating Generative Large Language Model Serving with Tree-based Speculative Inference" (2024)
관련 포스트
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
- [sglang] sglang diffusion 모델 성능 향상: Cache-DiT와 torch.compile의 최적화된 적용 순서
- [sglang] NixlKVManager 성능 향상: 비동기 및 멀티스레드 KV 전송 도입
SGLang 의 다른글
- 이전글 [SGLang] EAGLE CUDA Graph: 드래프트 모델 가속
- 현재글 : [SGLang] Tree Search & Verification: 트리 기반 추측과 검증
- 다음글 [SGLang] Grammar Manager: 구조화된 출력 생성의 통합 관리
댓글