본문으로 건너뛰기

[SGLang] EAGLE: 은닉 상태 기반 드래프트 모델

들어가며

EAGLE(Extrapolation Algorithm for Greater Language-model Efficiency)은 타겟 모델의 은닉 상태(hidden states)를 드래프트 모델의 입력으로 직접 활용하는 speculative decoding 기법이다. 독립 드래프트 모델은 타겟과 별도로 학습되어 예측 정확도가 낮지만, EAGLE은 타겟의 내부 표현을 그대로 가져와 더 높은 acceptance rate를 달성한다. SGLang의 EAGLEWorker는 이 알고리즘을 드래프트-검증-캐시관리까지 완전하게 구현한다.

구조도

┌─────────────────────────────────────────────────┐
│                 EAGLEWorker                       │
│                                                   │
│  ┌──────────────┐      ┌──────────────────┐      │
│  │ Target Model  │      │   Draft Model     │      │
│  │ (target_worker│      │ (draft_model_runner│      │
│  │  .model_runner│      │  shares embed/head)│      │
│  └───────┬──────┘      └────────┬─────────┘      │
│          │                      │                  │
│          │  hidden_states       │  topk_p,         │
│          │  + verified_id       │  topk_index       │
│          ▼                      ▼                  │
│  ┌──────────────┐      ┌──────────────────┐      │
│  │EagleDraftInput│─────▶│   Draft Phase     │      │
│  │ .hidden_states│      │ select_top_k ×N   │      │
│  │ .verified_id  │      │ build_tree         │      │
│  │ .topk_p/index │      └────────┬─────────┘      │
│  └──────────────┘               │                  │
│                                  ▼                  │
│                        ┌──────────────────┐        │
│                        │ EagleVerifyInput  │        │
│                        │ .draft_token      │        │
│                        │ .custom_mask      │        │
│                        │ .retrive_index    │        │
│                        └────────┬─────────┘        │
│                                  │                  │
│                                  ▼                  │
│                        ┌──────────────────┐        │
│                        │  Verify Phase     │        │
│                        │ target forward    │        │
│                        │ + tree_verify     │        │
│                        └──────────────────┘        │
└─────────────────────────────────────────────────┘

핵심 코드 분석

1. EAGLEWorker 초기화: 임베딩 공유

python/sglang/srt/speculative/eagle_worker.py에서 드래프트 모델은 타겟 모델의 embedding과 lm_head를 공유한다.

embed, head = self.target_worker.model_runner.model.get_embed_and_head()

if self.speculative_algorithm.is_eagle3():
    if (hasattr(self.draft_model_runner.model, "load_lm_head_from_target")
        and self.draft_model_runner.model.load_lm_head_from_target):
        self.draft_model_runner.model.set_embed_and_head(embed, head)
    else:
        self.draft_model_runner.model.set_embed(embed)
else:
    if self.hot_token_id is not None:
        head = head.clone()
        self.hot_token_id = self.hot_token_id.to(head.device)
        head.data = head.data[self.hot_token_id]
    self.draft_model_runner.model.set_embed_and_head(embed, head)

hot_token_id가 설정되면 lm_head를 고빈도 토큰 서브셋으로 축소하여 드래프트 모델의 출력 차원을 줄인다. 이는 드래프트 forward 속도를 높이는 핵심 최적화다.

2. EagleDraftInput: 은닉 상태 전달

python/sglang/srt/speculative/eagle_info.py에 정의된 EagleDraftInput은 검증 후 다음 드래프트 단계로 전달되는 데이터를 담는다.

@dataclass
class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
    topk_p: torch.Tensor = None          # (b, topk)
    topk_index: torch.Tensor = None      # (b, topk)
    hidden_states: torch.Tensor = None   # (b, hidden_size)
    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
    verified_id: torch.Tensor = None     # (b,)
    accept_length: torch.Tensor = None   # (b,)
    accept_length_cpu: List[int] = None

hidden_states가 타겟 모델에서 추출된 은닉 상태이며, 이것이 드래프트 모델의 조건 입력으로 사용된다. topk_ptopk_index는 이전 드래프트 step의 top-k 확률과 토큰 인덱스다.

3. 드래프트 생성: 다단계 top-k 선택

드래프트 단계에서는 select_top_k_tokens를 반복 호출하여 트리 구조의 후보를 생성한다.

def draft(self, batch: ScheduleBatch):
    spec_info = batch.spec_info
    topk_p, topk_index, hidden_states = (
        spec_info.topk_p, spec_info.topk_index, spec_info.hidden_states,
    )

    scores = None
    score_list, token_list, parents_list = [], [], []
    for i in range(self.speculative_num_steps):
        input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
            i, topk_p, topk_index, hidden_states, scores, self.topk
        )
        score_list.append(tree_info[0])
        token_list.append(tree_info[1])
        parents_list.append(tree_info[2])

        # Draft model forward
        logits_output = self.draft_model_runner.forward(forward_batch)
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)

매 step마다 드래프트 모델을 forward하고, 결과에서 top-k를 뽑아 다음 step의 입력으로 사용한다. scores는 누적 확률을 추적하여 최종적으로 가장 유망한 후보를 선택한다.

4. 트리 구축과 검증 입력 생성

드래프트 결과를 organize_draft_results로 정리하고 build_tree_kernel_efficient로 트리 마스크를 구성한다.

parent_list, top_scores_index, draft_tokens = organize_draft_results(
    score_list, token_list, parents_list, self.speculative_num_draft_tokens
)

tree_mask, position, retrive_index, retrive_next_token, retrive_next_sibling, draft_tokens = \
    build_tree_kernel_efficient(
        spec_info.verified_id, parent_list, top_scores_index, draft_tokens,
        batch.seq_lens, batch.seq_lens_sum,
        self.topk, self.speculative_num_steps, self.speculative_num_draft_tokens,
    )

return EagleVerifyInput(
    draft_token=draft_tokens, custom_mask=tree_mask,
    positions=position, retrive_index=retrive_index, ...
)

retrive_next_tokenretrive_next_sibling은 트리를 DFS로 순회하기 위한 인덱스다.

5. 검증: Greedy vs Sampling

EagleVerifyInput.verify()에서 타겟 모델의 logits와 드래프트 토큰을 비교한다.

if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
    target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
    predict, accept_index, accept_length = verify_tree_greedy_func(
        predicts=predict, accept_index=accept_index,
        candidates=candidates, retrive_index=self.retrive_index,
        retrive_next_token=self.retrive_next_token,
        retrive_next_sibling=self.retrive_next_sibling,
        target_predict=target_predict, topk=self.topk,
    )
else:
    target_probs = F.softmax(
        logits_output.next_token_logits / expanded_temperature, dim=-1
    )
    tree_speculative_sampling_target_only(
        predicts=predict, accept_index=accept_index,
        candidates=candidates, target_probs=target_probs,
        draft_probs=draft_probs, ...
    )

Greedy 모드에서는 verify_tree_greedy로 argmax 비교만 수행하고, sampling 모드에서는 tree_speculative_sampling_target_only CUDA 커널을 사용하여 rejection sampling을 병렬로 처리한다.

6. KV 캐시 해제

검증 후 거부된 토큰의 KV 캐시를 즉시 해제한다.

evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])

수락된 토큰만 남기고 나머지 캐시 슬롯을 반환하여 메모리 효율을 유지한다.

설계 근거

은닉 상태 활용의 효과: 독립 드래프트 모델은 타겟의 분포를 근사할 뿐이지만, EAGLE은 타겟의 마지막 hidden state를 직접 받으므로 동일한 파라미터 수 대비 훨씬 높은 acceptance rate를 달성한다.

트리 구조 탐색: 단순 chain이 아닌 tree 구조로 후보를 생성하면 하나의 검증 forward로 여러 경로를 동시에 평가할 수 있어 throughput이 높아진다.

임베딩/헤드 공유: 드래프트 모델이 타겟의 embedding과 lm_head를 직접 참조하여 별도 메모리 할당 없이 동일한 토큰 공간을 공유한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글