[SGLang] Speculative Decoding 개요: 원리와 구현 아키텍처
들어가며
LLM 추론에서 Autoregressive decoding은 한 번에 하나의 토큰만 생성하기 때문에 GPU 활용률이 낮다. Speculative Decoding은 작은 드래프트 모델로 여러 토큰을 빠르게 "추측"한 뒤, 타겟 모델이 이를 한 번에 검증하여 throughput을 높이는 기법이다. SGLang은 이 기법을 EAGLE, DFlash, N-gram 등 다양한 알고리즘으로 구현하며, 트리 기반 검증과 CUDA Graph 가속까지 통합한 완성도 높은 구현을 제공한다.
구조도
┌──────────────────────────────────────────────────────┐
│ SGLang Scheduler │
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌───────────┐ │
│ │ EAGLE Worker │ │DFlash Worker│ │NGRAM Worker│ │
│ └──────┬──────┘ └──────┬──────┘ └─────┬─────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌──────────────────────────────────────────────────┐│
│ │ BaseSpecWorker / TpModelWorker ││
│ │ ┌──────────┐ ┌──────────────┐ ││
│ │ │Draft Model│ │ Target Model │ ││
│ │ │(드래프트) │──▶│ (검증) │ ││
│ │ └──────────┘ └──────────────┘ ││
│ └──────────────────────────────────────────────────┘│
│ │ │
│ ▼ │
│ ┌──────────────┐ ┌───────────────┐ │
│ │ SpecInput │ │ SpecUtils │ │
│ │ (Draft/Verify│ │(트리 구축, │ │
│ │ 데이터) │ │ 검증 커널) │ │
│ └──────────────┘ └───────────────┘ │
└──────────────────────────────────────────────────────┘
Autoregressive vs Speculative Decoding 비교
| 항목 | Autoregressive | Speculative Decoding |
|---|---|---|
| 토큰 생성 | 1개/step | N개/step (드래프트) |
| GPU 활용률 | 낮음 (memory-bound) | 높음 (병렬 검증) |
| 모델 호출 | 타겟 모델 N회 | 드래프트 N회 + 타겟 1회 |
| 정확도 | 100% | 100% (rejection sampling) |
| 속도 향상 | 기준선 | 2-3x (acceptance rate 의존) |
Autoregressive:
Step 1 ──▶ Step 2 ──▶ Step 3 ──▶ Step 4 ──▶ Step 5
[tok1] [tok2] [tok3] [tok4] [tok5]
타겟모델 타겟모델 타겟모델 타겟모델 타겟모델
Speculative:
Draft Phase Verify Phase
┌──────────────┐ ┌──────────────────┐
│ 드래프트 모델 │ │ 타겟 모델 │
│ tok1→tok2→ │──▶│ 검증: tok1~tok5 │──▶ 수락된 토큰 출력
│ tok3→tok4→ │ │ (1회 forward) │
│ tok5 │ └──────────────────┘
└──────────────┘
핵심 코드 분석
1. 알고리즘 레지스트리: SpeculativeAlgorithm
SGLang은 모든 speculative decoding 알고리즘을 SpeculativeAlgorithm enum으로 관리한다. python/sglang/srt/speculative/spec_info.py에 정의되어 있다.
class SpeculativeAlgorithm(Enum):
DFLASH = auto()
EAGLE = auto()
EAGLE3 = auto()
STANDALONE = auto()
NGRAM = auto()
NONE = auto()
create_worker 메서드가 서버 설정에 따라 적절한 워커 클래스를 반환한다.
def create_worker(self, server_args: ServerArgs):
enable_overlap = not server_args.disable_overlap_schedule
if self.is_eagle() and server_args.enable_multi_layer_eagle:
if enable_overlap:
return MultiLayerEagleWorkerV2
return MultiLayerEagleWorker
elif self.is_eagle():
if enable_overlap:
return EAGLEWorkerV2
return EAGLEWorker
elif self.is_ngram():
return NGRAMWorker
enable_overlap에 따라 V1(동기)과 V2(비동기 오버랩) 워커를 구분하는 것이 특징이다.
2. 추상 인터페이스: BaseSpecWorker
python/sglang/srt/speculative/base_spec_worker.py에서 모든 speculative worker의 공통 인터페이스를 정의한다.
class BaseDraftWorker(ABC):
@abstractmethod
def draft():
pass
@abstractmethod
def draft_extend():
pass
class BaseSpecWorker(ABC):
@property
@abstractmethod
def target_worker(self) -> TpModelWorker:
pass
@property
@abstractmethod
def draft_worker(self) -> BaseDraftWorker:
pass
BaseDraftWorker는 드래프트 생성(draft)과 드래프트 확장(draft_extend)을 추상 메서드로 요구하고, BaseSpecWorker는 타겟/드래프트 워커 접근을 강제한다.
3. SpecInput 데이터 흐름
드래프트와 검증 단계 사이의 데이터는 SpecInput 추상 클래스를 통해 전달된다.
class SpecInput(ABC):
def __init__(self, spec_input_type: SpecInputType):
self.spec_input_type = spec_input_type
def is_draft_input(self) -> bool:
return self.spec_input_type in {
SpecInputType.EAGLE_DRAFT,
SpecInputType.DFLASH_DRAFT,
}
def is_verify_input(self) -> bool:
return self.spec_input_type in {
SpecInputType.EAGLE_VERIFY,
SpecInputType.DFLASH_VERIFY,
SpecInputType.NGRAM_VERIFY,
}
SpecInputType이 각 알고리즘의 입력 유형을 구분하여, attention backend에서 알고리즘별 처리를 분기할 수 있다.
4. 공통 유틸리티: spec_utils.py
트리 탐색, 토큰 선택, 시뮬레이션 등 알고리즘 공통 기능이 spec_utils.py에 모여 있다. 대표적으로 top-k 토큰 선택 함수는 트리 구조의 후보를 관리한다.
def select_top_k_tokens(i, topk_p, topk_index, hidden_states, scores, topk):
if i == 0:
input_ids = topk_index.flatten()
if hidden_states is not None:
hidden_states = hidden_states.repeat_interleave(topk, dim=0)
scores = topk_p
tree_info = (topk_p.unsqueeze(1), topk_index, ...)
else:
expand_scores = torch.mul(
scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
)
topk_cs_p, topk_cs_index = fast_topk(
expand_scores.flatten(start_dim=1), topk, dim=-1
)
...
return input_ids, hidden_states, scores, tree_info
첫 번째 step에서는 단순히 top-k를 선택하고, 이후 step에서는 이전 scores와 현재 확률을 곱하여 누적 점수 기반으로 가지를 선택한다.
설계 근거
SGLang의 speculative decoding 설계에는 몇 가지 핵심 결정이 있다.
1. 알고리즘 교체 가능성: SpeculativeAlgorithm enum과 create_worker 팩토리 패턴으로 서버 인자 하나로 알고리즘을 교체할 수 있다. EAGLE, DFlash, N-gram 모두 같은 스케줄러 인터페이스를 공유한다.
2. 메모리 풀 공유: 드래프트와 타겟 워커가 req_to_token_pool과 token_to_kv_pool_allocator를 공유하여 KV 캐시 메모리를 절약한다.
3. V1/V2 분리: V1은 동기 실행, V2는 드래프트와 타겟 실행을 오버랩하여 파이프라인 효율을 높인다. disable_overlap_schedule 플래그로 제어한다.
4. Triton 커널 활용: assign_req_to_token_pool, create_extend_after_decode_spec_info 등 빈번한 인덱스 연산을 Triton JIT 커널로 구현하여 CPU-GPU 동기화를 최소화한다.
관련 포스트
- EAGLE: 은닉 상태 기반 드래프트 모델
- EAGLE v2: 개선된 드래프트 알고리즘
- Multi-Layer EAGLE: 다계층 드래프트
- N-gram Draft: 모델 프리 투기적 디코딩
- DFlash: Flash 기반 고속 드래프팅
- EAGLE CUDA Graph: 드래프트 모델 가속
- Tree Search & Verification: 트리 기반 추측과 검증
참고
- SGLang Speculative 소스 코드
- Leviathan et al., "Fast Inference from Transformers via Speculative Decoding" (2023)
- Chen et al., "Accelerating Large Language Model Decoding with Speculative Sampling" (2023)
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] FlashInfer + TensorRT-LLM MoE: 하이브리드 MoE 커널
- 현재글 : [SGLang] Speculative Decoding 개요: 원리와 구현 아키텍처
- 다음글 [SGLang] EAGLE: 은닉 상태 기반 드래프트 모델
댓글