본문으로 건너뛰기

[SGLang] Speculative Decoding 개요: 원리와 구현 아키텍처

들어가며

LLM 추론에서 Autoregressive decoding은 한 번에 하나의 토큰만 생성하기 때문에 GPU 활용률이 낮다. Speculative Decoding은 작은 드래프트 모델로 여러 토큰을 빠르게 "추측"한 뒤, 타겟 모델이 이를 한 번에 검증하여 throughput을 높이는 기법이다. SGLang은 이 기법을 EAGLE, DFlash, N-gram 등 다양한 알고리즘으로 구현하며, 트리 기반 검증과 CUDA Graph 가속까지 통합한 완성도 높은 구현을 제공한다.

구조도

┌──────────────────────────────────────────────────────┐
│                    SGLang Scheduler                   │
│                                                      │
│  ┌─────────────┐    ┌─────────────┐    ┌───────────┐ │
│  │ EAGLE Worker │    │DFlash Worker│    │NGRAM Worker│ │
│  └──────┬──────┘    └──────┬──────┘    └─────┬─────┘ │
│         │                  │                 │       │
│         ▼                  ▼                 ▼       │
│  ┌──────────────────────────────────────────────────┐│
│  │              BaseSpecWorker / TpModelWorker       ││
│  │  ┌──────────┐   ┌──────────────┐                 ││
│  │  │Draft Model│   │ Target Model │                 ││
│  │  │(드래프트) │──▶│  (검증)      │                 ││
│  │  └──────────┘   └──────────────┘                 ││
│  └──────────────────────────────────────────────────┘│
│         │                                            │
│         ▼                                            │
│  ┌──────────────┐  ┌───────────────┐                 │
│  │  SpecInput   │  │ SpecUtils     │                 │
│  │ (Draft/Verify│  │(트리 구축,    │                 │
│  │  데이터)     │  │ 검증 커널)    │                 │
│  └──────────────┘  └───────────────┘                 │
└──────────────────────────────────────────────────────┘

Autoregressive vs Speculative Decoding 비교

항목 Autoregressive Speculative Decoding
토큰 생성 1개/step N개/step (드래프트)
GPU 활용률 낮음 (memory-bound) 높음 (병렬 검증)
모델 호출 타겟 모델 N회 드래프트 N회 + 타겟 1회
정확도 100% 100% (rejection sampling)
속도 향상 기준선 2-3x (acceptance rate 의존)
Autoregressive:
  Step 1 ──▶ Step 2 ──▶ Step 3 ──▶ Step 4 ──▶ Step 5
  [tok1]     [tok2]     [tok3]     [tok4]     [tok5]
  타겟모델   타겟모델    타겟모델    타겟모델    타겟모델

Speculative:
  Draft Phase          Verify Phase
  ┌──────────────┐    ┌──────────────────┐
  │ 드래프트 모델  │    │    타겟 모델      │
  │ tok1→tok2→   │──▶│ 검증: tok1~tok5  │──▶ 수락된 토큰 출력
  │ tok3→tok4→   │    │ (1회 forward)    │
  │ tok5          │    └──────────────────┘
  └──────────────┘

핵심 코드 분석

1. 알고리즘 레지스트리: SpeculativeAlgorithm

SGLang은 모든 speculative decoding 알고리즘을 SpeculativeAlgorithm enum으로 관리한다. python/sglang/srt/speculative/spec_info.py에 정의되어 있다.

class SpeculativeAlgorithm(Enum):
    DFLASH = auto()
    EAGLE = auto()
    EAGLE3 = auto()
    STANDALONE = auto()
    NGRAM = auto()
    NONE = auto()

create_worker 메서드가 서버 설정에 따라 적절한 워커 클래스를 반환한다.

def create_worker(self, server_args: ServerArgs):
    enable_overlap = not server_args.disable_overlap_schedule
    if self.is_eagle() and server_args.enable_multi_layer_eagle:
        if enable_overlap:
            return MultiLayerEagleWorkerV2
        return MultiLayerEagleWorker
    elif self.is_eagle():
        if enable_overlap:
            return EAGLEWorkerV2
        return EAGLEWorker
    elif self.is_ngram():
        return NGRAMWorker

enable_overlap에 따라 V1(동기)과 V2(비동기 오버랩) 워커를 구분하는 것이 특징이다.

2. 추상 인터페이스: BaseSpecWorker

python/sglang/srt/speculative/base_spec_worker.py에서 모든 speculative worker의 공통 인터페이스를 정의한다.

class BaseDraftWorker(ABC):
    @abstractmethod
    def draft():
        pass

    @abstractmethod
    def draft_extend():
        pass

class BaseSpecWorker(ABC):
    @property
    @abstractmethod
    def target_worker(self) -> TpModelWorker:
        pass

    @property
    @abstractmethod
    def draft_worker(self) -> BaseDraftWorker:
        pass

BaseDraftWorker는 드래프트 생성(draft)과 드래프트 확장(draft_extend)을 추상 메서드로 요구하고, BaseSpecWorker는 타겟/드래프트 워커 접근을 강제한다.

3. SpecInput 데이터 흐름

드래프트와 검증 단계 사이의 데이터는 SpecInput 추상 클래스를 통해 전달된다.

class SpecInput(ABC):
    def __init__(self, spec_input_type: SpecInputType):
        self.spec_input_type = spec_input_type

    def is_draft_input(self) -> bool:
        return self.spec_input_type in {
            SpecInputType.EAGLE_DRAFT,
            SpecInputType.DFLASH_DRAFT,
        }

    def is_verify_input(self) -> bool:
        return self.spec_input_type in {
            SpecInputType.EAGLE_VERIFY,
            SpecInputType.DFLASH_VERIFY,
            SpecInputType.NGRAM_VERIFY,
        }

SpecInputType이 각 알고리즘의 입력 유형을 구분하여, attention backend에서 알고리즘별 처리를 분기할 수 있다.

4. 공통 유틸리티: spec_utils.py

트리 탐색, 토큰 선택, 시뮬레이션 등 알고리즘 공통 기능이 spec_utils.py에 모여 있다. 대표적으로 top-k 토큰 선택 함수는 트리 구조의 후보를 관리한다.

def select_top_k_tokens(i, topk_p, topk_index, hidden_states, scores, topk):
    if i == 0:
        input_ids = topk_index.flatten()
        if hidden_states is not None:
            hidden_states = hidden_states.repeat_interleave(topk, dim=0)
        scores = topk_p
        tree_info = (topk_p.unsqueeze(1), topk_index, ...)
    else:
        expand_scores = torch.mul(
            scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
        )
        topk_cs_p, topk_cs_index = fast_topk(
            expand_scores.flatten(start_dim=1), topk, dim=-1
        )
        ...
    return input_ids, hidden_states, scores, tree_info

첫 번째 step에서는 단순히 top-k를 선택하고, 이후 step에서는 이전 scores와 현재 확률을 곱하여 누적 점수 기반으로 가지를 선택한다.

설계 근거

SGLang의 speculative decoding 설계에는 몇 가지 핵심 결정이 있다.

1. 알고리즘 교체 가능성: SpeculativeAlgorithm enum과 create_worker 팩토리 패턴으로 서버 인자 하나로 알고리즘을 교체할 수 있다. EAGLE, DFlash, N-gram 모두 같은 스케줄러 인터페이스를 공유한다.

2. 메모리 풀 공유: 드래프트와 타겟 워커가 req_to_token_pooltoken_to_kv_pool_allocator를 공유하여 KV 캐시 메모리를 절약한다.

3. V1/V2 분리: V1은 동기 실행, V2는 드래프트와 타겟 실행을 오버랩하여 파이프라인 효율을 높인다. disable_overlap_schedule 플래그로 제어한다.

4. Triton 커널 활용: assign_req_to_token_pool, create_extend_after_decode_spec_info 등 빈번한 인덱스 연산을 Triton JIT 커널로 구현하여 CPU-GPU 동기화를 최소화한다.

관련 포스트

참고

  • SGLang Speculative 소스 코드
  • Leviathan et al., "Fast Inference from Transformers via Speculative Decoding" (2023)
  • Chen et al., "Accelerating Large Language Model Decoding with Speculative Sampling" (2023)

댓글

관련 포스트

SGLang 의 다른글