[SGLang] EAGLE: 은닉 상태 기반 드래프트 모델
들어가며
EAGLE(Extrapolation Algorithm for Greater Language-model Efficiency)은 타겟 모델의 은닉 상태(hidden states)를 드래프트 모델의 입력으로 직접 활용하는 speculative decoding 기법이다. 독립 드래프트 모델은 타겟과 별도로 학습되어 예측 정확도가 낮지만, EAGLE은 타겟의 내부 표현을 그대로 가져와 더 높은 acceptance rate를 달성한다. SGLang의 EAGLEWorker는 이 알고리즘을 드래프트-검증-캐시관리까지 완전하게 구현한다.
구조도
┌─────────────────────────────────────────────────┐
│ EAGLEWorker │
│ │
│ ┌──────────────┐ ┌──────────────────┐ │
│ │ Target Model │ │ Draft Model │ │
│ │ (target_worker│ │ (draft_model_runner│ │
│ │ .model_runner│ │ shares embed/head)│ │
│ └───────┬──────┘ └────────┬─────────┘ │
│ │ │ │
│ │ hidden_states │ topk_p, │
│ │ + verified_id │ topk_index │
│ ▼ ▼ │
│ ┌──────────────┐ ┌──────────────────┐ │
│ │EagleDraftInput│─────▶│ Draft Phase │ │
│ │ .hidden_states│ │ select_top_k ×N │ │
│ │ .verified_id │ │ build_tree │ │
│ │ .topk_p/index │ └────────┬─────────┘ │
│ └──────────────┘ │ │
│ ▼ │
│ ┌──────────────────┐ │
│ │ EagleVerifyInput │ │
│ │ .draft_token │ │
│ │ .custom_mask │ │
│ │ .retrive_index │ │
│ └────────┬─────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────┐ │
│ │ Verify Phase │ │
│ │ target forward │ │
│ │ + tree_verify │ │
│ └──────────────────┘ │
└─────────────────────────────────────────────────┘
핵심 코드 분석
1. EAGLEWorker 초기화: 임베딩 공유
python/sglang/srt/speculative/eagle_worker.py에서 드래프트 모델은 타겟 모델의 embedding과 lm_head를 공유한다.
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
if self.speculative_algorithm.is_eagle3():
if (hasattr(self.draft_model_runner.model, "load_lm_head_from_target")
and self.draft_model_runner.model.load_lm_head_from_target):
self.draft_model_runner.model.set_embed_and_head(embed, head)
else:
self.draft_model_runner.model.set_embed(embed)
else:
if self.hot_token_id is not None:
head = head.clone()
self.hot_token_id = self.hot_token_id.to(head.device)
head.data = head.data[self.hot_token_id]
self.draft_model_runner.model.set_embed_and_head(embed, head)
hot_token_id가 설정되면 lm_head를 고빈도 토큰 서브셋으로 축소하여 드래프트 모델의 출력 차원을 줄인다. 이는 드래프트 forward 속도를 높이는 핵심 최적화다.
2. EagleDraftInput: 은닉 상태 전달
python/sglang/srt/speculative/eagle_info.py에 정의된 EagleDraftInput은 검증 후 다음 드래프트 단계로 전달되는 데이터를 담는다.
@dataclass
class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
topk_p: torch.Tensor = None # (b, topk)
topk_index: torch.Tensor = None # (b, topk)
hidden_states: torch.Tensor = None # (b, hidden_size)
capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
verified_id: torch.Tensor = None # (b,)
accept_length: torch.Tensor = None # (b,)
accept_length_cpu: List[int] = None
hidden_states가 타겟 모델에서 추출된 은닉 상태이며, 이것이 드래프트 모델의 조건 입력으로 사용된다. topk_p와 topk_index는 이전 드래프트 step의 top-k 확률과 토큰 인덱스다.
3. 드래프트 생성: 다단계 top-k 선택
드래프트 단계에서는 select_top_k_tokens를 반복 호출하여 트리 구조의 후보를 생성한다.
def draft(self, batch: ScheduleBatch):
spec_info = batch.spec_info
topk_p, topk_index, hidden_states = (
spec_info.topk_p, spec_info.topk_index, spec_info.hidden_states,
)
scores = None
score_list, token_list, parents_list = [], [], []
for i in range(self.speculative_num_steps):
input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
i, topk_p, topk_index, hidden_states, scores, self.topk
)
score_list.append(tree_info[0])
token_list.append(tree_info[1])
parents_list.append(tree_info[2])
# Draft model forward
logits_output = self.draft_model_runner.forward(forward_batch)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
매 step마다 드래프트 모델을 forward하고, 결과에서 top-k를 뽑아 다음 step의 입력으로 사용한다. scores는 누적 확률을 추적하여 최종적으로 가장 유망한 후보를 선택한다.
4. 트리 구축과 검증 입력 생성
드래프트 결과를 organize_draft_results로 정리하고 build_tree_kernel_efficient로 트리 마스크를 구성한다.
parent_list, top_scores_index, draft_tokens = organize_draft_results(
score_list, token_list, parents_list, self.speculative_num_draft_tokens
)
tree_mask, position, retrive_index, retrive_next_token, retrive_next_sibling, draft_tokens = \
build_tree_kernel_efficient(
spec_info.verified_id, parent_list, top_scores_index, draft_tokens,
batch.seq_lens, batch.seq_lens_sum,
self.topk, self.speculative_num_steps, self.speculative_num_draft_tokens,
)
return EagleVerifyInput(
draft_token=draft_tokens, custom_mask=tree_mask,
positions=position, retrive_index=retrive_index, ...
)
retrive_next_token과 retrive_next_sibling은 트리를 DFS로 순회하기 위한 인덱스다.
5. 검증: Greedy vs Sampling
EagleVerifyInput.verify()에서 타겟 모델의 logits와 드래프트 토큰을 비교한다.
if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
predict, accept_index, accept_length = verify_tree_greedy_func(
predicts=predict, accept_index=accept_index,
candidates=candidates, retrive_index=self.retrive_index,
retrive_next_token=self.retrive_next_token,
retrive_next_sibling=self.retrive_next_sibling,
target_predict=target_predict, topk=self.topk,
)
else:
target_probs = F.softmax(
logits_output.next_token_logits / expanded_temperature, dim=-1
)
tree_speculative_sampling_target_only(
predicts=predict, accept_index=accept_index,
candidates=candidates, target_probs=target_probs,
draft_probs=draft_probs, ...
)
Greedy 모드에서는 verify_tree_greedy로 argmax 비교만 수행하고, sampling 모드에서는 tree_speculative_sampling_target_only CUDA 커널을 사용하여 rejection sampling을 병렬로 처리한다.
6. KV 캐시 해제
검증 후 거부된 토큰의 KV 캐시를 즉시 해제한다.
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
수락된 토큰만 남기고 나머지 캐시 슬롯을 반환하여 메모리 효율을 유지한다.
설계 근거
은닉 상태 활용의 효과: 독립 드래프트 모델은 타겟의 분포를 근사할 뿐이지만, EAGLE은 타겟의 마지막 hidden state를 직접 받으므로 동일한 파라미터 수 대비 훨씬 높은 acceptance rate를 달성한다.
트리 구조 탐색: 단순 chain이 아닌 tree 구조로 후보를 생성하면 하나의 검증 forward로 여러 경로를 동시에 평가할 수 있어 throughput이 높아진다.
임베딩/헤드 공유: 드래프트 모델이 타겟의 embedding과 lm_head를 직접 참조하여 별도 메모리 할당 없이 동일한 토큰 공간을 공유한다.
관련 포스트
- Speculative Decoding 개요
- EAGLE v2: 개선된 드래프트 알고리즘
- Multi-Layer EAGLE: 다계층 드래프트
- EAGLE CUDA Graph: 드래프트 모델 가속
- Tree Search & Verification
참고
- SGLang EAGLE Worker 소스
- SGLang EAGLE Info 소스
- Li et al., "EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty" (2024)
관련 포스트
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [sglang] SGLang EAGLE 디코딩 최적화: 불필요한 Softmax 연산 제거로 성능 향상
- [논문리뷰] Your Embedding Model is SMARTer Than You Think
- [sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
SGLang 의 다른글
- 이전글 [SGLang] Speculative Decoding 개요: 원리와 구현 아키텍처
- 현재글 : [SGLang] EAGLE: 은닉 상태 기반 드래프트 모델
- 다음글 [SGLang] EAGLE v2: 개선된 드래프트 알고리즘
댓글