본문으로 건너뛰기

[SGLang] Tree Search & Verification: 트리 기반 추측과 검증

들어가며

Speculative decoding에서 드래프트 토큰을 트리 구조로 배치하면 하나의 타겟 모델 forward로 여러 후보 경로를 동시에 검증할 수 있다. 선형 체인(topk=1)은 하나의 경로만 검증하지만, 트리(topk>1)는 분기된 경로를 병렬로 평가하여 최소 하나 이상의 긴 수락 경로를 확보할 확률을 높인다. SGLang의 spec_utils.pyeagle_utils.py에 이 트리 구축/검증 로직이 구현되어 있다.

구조도

                      [root: verified_id]
                     /         |          \
                [tok_A]     [tok_B]     [tok_C]     ← step 0 (topk=3)
               /    \       /    \       /    \
          [tok_D] [tok_E] [tok_F] [tok_G] [tok_H] [tok_I]  ← step 1
            |       |       |       |       |       |
          [...]   [...]   [...]   [...]   [...]   [...]     ← step 2

  retrive_next_token:  각 노드의 첫 번째 자식 인덱스
  retrive_next_sibling: 각 노드의 다음 형제 인덱스
  retrive_index:        검증 결과에서 수락 경로 추적용 인덱스

핵심 코드 분석

1. 트리 구축: organize_draft_results

python/sglang/srt/speculative/eagle_utils.py에서 드래프트 결과를 정리한다.

def organize_draft_results(score_list, token_list, parents_list, num_draft_token):
    score_list = torch.cat(score_list, dim=1).flatten(1)
    ss_token_list = torch.cat(token_list, dim=1)
    top_scores = torch.topk(score_list, num_draft_token - 1, dim=-1)
    top_scores_index = top_scores.indices
    top_scores_index = torch.sort(top_scores_index).values
    draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)

    if len(parents_list) > 1:
        parent_list = torch.cat(parents_list[:-1], dim=1)
    else:
        parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)

    return parent_list, top_scores_index, draft_tokens

모든 step의 score를 합쳐 상위 num_draft_token - 1개를 선택한다. -1인 이유는 root(verified_id)가 이미 하나를 차지하기 때문이다. top_scores_index를 정렬하여 트리 구조를 위치 기반으로 일관되게 유지한다.

2. 트리 마스크 생성: build_tree_kernel_efficient

def build_tree_kernel_efficient(
    verified_id, parent_list, top_scores_index, draft_tokens,
    seq_lens, seq_lens_sum, topk, spec_steps, num_verify_tokens,
    tree_mask_mode=TreeMaskMode.FULL_MASK, ...
):
    draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()

    if tree_mask_mode == TreeMaskMode.FULL_MASK:
        tree_mask = torch.full(
            (seq_lens_sum * num_verify_tokens + num_verify_tokens * num_verify_tokens * bs,),
            True, dtype=torch.bool, device=device,
        )
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
        tree_mask = torch.zeros(
            (num_verify_tokens * bs,), dtype=packed_dtypes[packed_dtype_idx], device=device,
        )

트리 마스크는 3가지 모드를 지원한다:

  • FULL_MASK: prefix + draft 영역 전체를 포함하는 완전 마스크
  • QLEN_ONLY: draft 영역만 포함하는 축소 마스크 (FlashInfer 호환)
  • QLEN_ONLY_BITPACKING: bitwise packing으로 메모리 절약

sgl_build_tree_kernel_efficient CUDA 커널이 parent_list에서 트리 마스크, position, retrieve 인덱스를 한 번에 계산한다.

3. DFS 기반 Grammar 트리 순회

구조화된 출력(grammar)을 지원하기 위해 DFS로 트리를 순회한다.

def traverse_tree(retrieve_next_token, retrieve_next_sibling, draft_tokens, grammar, allocate_token_bitmask, vocab_size):
    def dfs(curr, retrieve_next_token, retrieve_next_sibling, parent_pos):
        if curr == 0:
            accepted = True
        else:
            parent_bitmask = allocate_token_bitmask[parent_pos]
            curr_token_id = draft_tokens[curr]
            accepted = (parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))) != 0

        if accepted:
            if curr != 0:
                grammar.accept_token(draft_tokens[curr])
            if not grammar.is_terminated():
                grammar.fill_vocab_mask(allocate_token_bitmask, curr)
                if retrieve_next_token[curr] != -1:
                    dfs(retrieve_next_token[curr], ...)
            if curr != 0:
                grammar.rollback(1)

        if retrieve_next_sibling[curr] != -1:
            dfs(retrieve_next_sibling[curr], ..., parent_pos)

    dfs(0, retrieve_next_token, retrieve_next_sibling, -1)

retrieve_next_token이 자식, retrieve_next_sibling이 형제를 가리키는 이진 트리 표현을 사용한다. Grammar에서 수락된 토큰은 accept_token으로 상태를 전진시키고, 거부 시 rollback으로 되돌린다. 각 위치에서 fill_vocab_mask로 다음 허용 토큰의 bitmask를 생성한다.

4. Greedy 트리 검증

eagle_utils.pyverify_tree_greedy_func는 CUDA 커널 기반 greedy 검증을 수행한다.

predict, accept_index, accept_length = verify_tree_greedy_func(
    predicts=predict,           # 출력: 각 위치의 예측 토큰
    accept_index=accept_index,  # 출력: 수락된 인덱스
    accept_token_num=accept_length,
    candidates=candidates,      # 드래프트 토큰
    retrive_index=self.retrive_index,
    retrive_next_token=self.retrive_next_token,
    retrive_next_sibling=self.retrive_next_sibling,
    target_predict=target_predict,  # 타겟 모델의 argmax
    topk=self.topk,
)

retrive_next_tokenretrive_next_sibling을 따라 트리를 순회하며, 타겟 예측과 드래프트 토큰이 일치하면 수락하고 자식으로 내려간다. 불일치 시 형제로 이동한다.

5. Sampling 기반 트리 검증

non-greedy 샘플링 시에는 rejection sampling을 사용한다.

tree_speculative_sampling_target_only(
    predicts=predict,
    accept_index=accept_index,
    accept_token_num=accept_length,
    candidates=candidates,
    retrive_index=self.retrive_index,
    retrive_next_token=self.retrive_next_token,
    retrive_next_sibling=self.retrive_next_sibling,
    uniform_samples=coins,
    uniform_samples_for_final_sampling=coins_for_final_sampling,
    target_probs=target_probs,
    draft_probs=draft_probs,
    threshold_single=server_args.speculative_accept_threshold_single,
    threshold_acc=server_args.speculative_accept_threshold_acc,
    deterministic=True,
)

threshold_singlethreshold_acc가 단일 토큰 및 누적 수락 임계값을 제어한다. uniform_samples는 rejection sampling용 난수, coins_for_final_sampling은 마지막 bonus 토큰 샘플링용이다.

6. 시뮬레이션 모드

벤치마킹을 위한 acceptance length 시뮬레이션도 지원한다.

def generate_simulated_accept_index(accept_index, predict, accept_length, bs, spec_steps,
                                     simulate_acc_len, simulate_acc_method):
    if simulate_acc_method == "match-expected":
        lower = int(simulate_acc_len // 1)
        upper = lower + 1 if lower < spec_steps + 1 else lower
        weight_upper = simulate_acc_len - lower
        probs = torch.tensor([1 - weight_upper, weight_upper])
        sampled_index = torch.multinomial(probs, num_samples=1)
        simulate_acc_len = lower if sampled_index == 0 else upper

SGLANG_SIMULATE_ACC_LEN 환경변수로 원하는 평균 acceptance length를 설정하여, 실제 모델 없이도 speculative decoding의 스케줄링 성능을 테스트할 수 있다.

7. TP 동기화

멀티 GPU에서 검증 결과가 일관되도록 broadcast한다.

tp_group = get_attention_tp_group() if is_dp_attention_enabled() else get_tp_group()
if tp_group.world_size > 1:
    tp_group.broadcast(predict, src=0)
    tp_group.broadcast(accept_index, src=0)
    tp_group.broadcast(accept_length, src=0)

softmax의 floating-point 비결정성으로 인해 GPU마다 다른 토큰을 샘플링할 수 있으므로, rank 0의 결과를 broadcast하여 일관성을 보장한다.

TreeMaskMode 비교

모드 크기 용도
FULL_MASK seq_lens_sum * N + N^2 * bs 범용, 모든 backend 지원
QLEN_ONLY N^2 * bs FlashInfer 최적화
QLEN_ONLY_BITPACKING N * bs (packed) 메모리 최적화

N = num_verify_tokens, bs = batch_size

설계 근거

이진 트리 인코딩: next_token(첫 자식)과 next_sibling(다음 형제)으로 N-ary 트리를 2개의 포인터로 표현한다. 이는 CUDA 커널에서 효율적으로 순회 가능한 구조다.

Score 기반 가지치기: 전체 후보 중 num_draft_token - 1개만 선택하여 검증 비용을 제한한다. 점수가 낮은 가지는 자동으로 제거된다.

Grammar 통합: 구조화된 출력에서도 speculative decoding을 사용할 수 있도록, 트리 순회 중 grammar 상태를 추적하고 bitmask를 생성한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글