[vLLM] EAGLE: 은닉 상태 기반 드래프트로 Speculative Decoding을 강화하다
들어가며
일반적인 speculative decoding에서 드래프트 모델은 타겟 모델과 독립적으로 토큰을 예측한다. 하지만 타겟 모델이 이미 계산한 은닉 상태(hidden states)에는 풍부한 정보가 담겨 있다. EAGLE은 이 은닉 상태를 드래프트 모델의 입력으로 직접 전달하여 드래프트 정확도를 대폭 향상시킨다. 독립 드래프트 모델 대비 수락률이 크게 높아져 실질적인 가속 효과가 더 크다.
- 논문: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty (arxiv 2409.12191)
- 공식 문서: https://docs.vllm.ai
공식 문서
vLLM 공식 문서: EAGLE Speculative Decoding
핵심 구조/코드 분석
EAGLE Proposer의 핵심 차이
vllm/v1/spec_decode/eagle.py에서 EAGLE의 핵심은 pass_hidden_states_to_model=True이다:
class SpecDecodeBaseProposer:
def __init__(self, vllm_config, device,
pass_hidden_states_to_model, ...):
self.pass_hidden_states_to_model = pass_hidden_states_to_model
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = (
self.draft_model_config.get_inputs_embeds_size()
)
독립 드래프트 모델(DraftModelProposer)은 pass_hidden_states_to_model=False로 생성되지만, EAGLE는 True로 설정된다. 이 플래그 하나가 전체 데이터 흐름을 바꾼다.
은닉 상태 버퍼 관리
self.hidden_states = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype, device=device
)
타겟 모델의 마지막 레이어 은닉 상태가 이 버퍼에 저장된다. EAGLE 드래프트 모델은 이전 토큰의 임베딩과 이 은닉 상태를 결합하여 입력으로 사용한다.
병렬 드래프팅과 마스킹
EAGLE은 병렬 드래프팅(parallel drafting)을 지원한다:
self.parallel_drafting = self.speculative_config.parallel_drafting
if self.parallel_drafting:
self._init_parallel_drafting_params()
def _init_parallel_drafting_params(self):
model_hf_config = self.draft_model_config.hf_config
# 마스크 토큰 ID 설정
if hasattr(model_hf_config, "pard_token"):
self.parallel_drafting_token_id = model_hf_config.pard_token
elif hasattr(model_hf_config, "ptd_token_id"):
self.parallel_drafting_token_id = model_hf_config.ptd_token_id
if self.pass_hidden_states_to_model:
self.parallel_drafting_hidden_state_tensor = torch.empty(
self.hidden_size, dtype=self.dtype, device=self.device
)
병렬 드래프팅에서는 모든 speculative 토큰을 한 번의 forward로 생성한다. 아직 생성되지 않은 위치에는 parallel_drafting_token_id(마스크 토큰)를 넣고, 해당 위치의 은닉 상태도 별도 텐서(parallel_drafting_hidden_state_tensor)로 채운다.
거부된 토큰 처리
self.is_rejected_token_mask = torch.zeros(
(self.max_num_tokens,), dtype=torch.bool, device=device
)
self.is_masked_token_mask = torch.zeros(
(self.max_num_tokens,), dtype=torch.bool, device=device
)
타겟 모델이 검증(verify) 후 거부된 토큰은 is_rejected_token_mask로 표시된다. 이 마스크에 따라 KV 캐시의 slot_mapping이 패딩 슬롯(PADDING_SLOT_ID)으로 교체되어, 거부된 토큰의 KV가 캐시를 오염시키지 않는다.
슬롯 매핑과 KV 캐시 관리
EAGLE의 복잡한 부분은 드래프트 토큰의 KV 캐시 관리이다:
from vllm.v1.spec_decode.utils import (
PADDING_SLOT_ID,
compute_new_slot_mapping,
copy_and_expand_eagle_inputs_kernel,
eagle_prepare_inputs_padded_kernel,
eagle_prepare_next_token_padded_kernel,
eagle_step_update_slot_mapping_and_metadata,
extend_all_queries_by_N,
)
self._slot_mapping_buffer = torch.zeros(
self.max_positions, dtype=torch.int64, device=device,
)
드래프트 토큰은 임시로 KV 블록에 기록되지만, 거부되면 해당 슬롯이 무효화된다. eagle_step_update_slot_mapping_and_metadata 커널이 매 스텝마다 이 갱신을 수행한다.
트리 구조 드래프팅 상세
self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(
spec_token_tree
)
tree_depth = len(self.tree_choices[-1])
# 레벨별 속성 사전 계산
num_drafts_per_level = [0] * tree_depth
for node in self.tree_choices:
num_drafts_per_level[len(node) - 1] += 1
self.cu_drafts_per_level = [num_drafts_per_level[0]]
self.child_drafts_per_level = [num_drafts_per_level[0]]
for level in range(1, tree_depth):
self.cu_drafts_per_level.append(
self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
)
self.child_drafts_per_level.append(
num_drafts_per_level[level] // num_drafts_per_level[level - 1]
)
EAGLE은 트리 구조로 여러 가지를 동시에 탐색한다. cu_drafts_per_level은 각 레벨까지의 누적 드래프트 수, child_drafts_per_level은 각 부모 노드당 자식 수이다. 이 사전 계산으로 트리 탐색 시 인덱싱 오버헤드를 제거한다.
CUDA Graph 지원
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self.input_ids = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device=device
)
EAGLE 드래프트 모델도 CUDA graph로 캡처하여 커널 런칭 오버헤드를 제거한다. 드래프트 모델이 작기 때문에 커널 런칭 비용이 상대적으로 크므로, CUDA graph의 효과가 더 두드러진다.
왜 이 설계인가
-
높은 수락률: 은닉 상태에는 다음 토큰 예측에 필요한 대부분의 정보가 이미 있다. 이를 활용하면 독립 드래프트 모델 대비 수락률이 70-90%까지 올라간다.
-
경량 드래프트 모델: EAGLE의 드래프트 모델은 타겟 모델의 1-2개 Transformer 레이어로 구성되어 매우 가볍다. 은닉 상태가 이미 풍부한 정보를 제공하기 때문이다.
-
트리 드래프팅: 분기 탐색으로 수락률을 더 높인다. GPU 병렬성을 활용하여 추가 연산 비용 대비 높은 수락 토큰 수를 달성한다.
-
Reject 안전성: slot_mapping 마스킹으로 거부된 토큰의 KV 캐시가 완벽하게 무효화된다. 이 덕분에 정확성이 보장되면서도 캐시 관리가 단순하다.
EAGLE은 "타겟 모델이 이미 계산한 정보를 재활용"한다는 직관적이면서도 강력한 아이디어로, 현재 vLLM에서 가장 널리 사용되는 speculative decoding 방식이다.
논문 핵심 내용
EAGLE-2 논문은 EAGLE의 컨텍스트 인식 동적 드래프트 트리를 도입하여 speculative decoding의 효율을 한 단계 더 끌어올렸다. EAGLE-1이 고정된 트리 구조를 사용했다면, EAGLE-2는 각 디코딩 스텝마다 드래프트 토큰의 신뢰도를 기반으로 트리를 동적으로 구성한다.
벤치마크 결과를 보면, EAGLE-2는 기존 EAGLE 대비 25-39%의 추가 속도 향상을 달성했다:
| 모델 | 벤치마크 | EAGLE 속도비 | EAGLE-2 속도비 | 개선폭 |
|---|---|---|---|---|
| Vicuna-7B | MT-bench (temp=0) | 2.90x | 3.62x | +25% |
| Vicuna-13B | MT-bench (temp=0) | 3.07x | 4.26x | +39% |
| LLaMA2-Chat 13B | MT-bench (temp=0) | 3.03x | 4.21x | +39% |
평균 수락 길이(acceptance length) 비교:
| 모델 | EAGLE | EAGLE-2 |
|---|---|---|
| Vicuna-7B | 3.94 | 4.98 |
| Vicuna-13B | 3.98 | 4.83 |
| LLaMA2-Chat 13B | 3.90 | 4.75 |
다른 speculative decoding 방법들과 비교하면, EAGLE-2는 Medusa 대비 약 2배, Lookahead 대비 약 2.3배 빠르다. 전체적으로 3.05x-4.26x 범위의 속도 향상을 보이며, 이는 lossless한 가속(출력 분포가 원본과 동일)이라는 점에서 매우 인상적이다.
관련 포스트
vLLM 의 다른글
- 이전글 [vLLM] Speculative Decoding: 드래프트 모델로 LLM 디코딩을 가속하는 원리
- 현재글 : [vLLM] EAGLE: 은닉 상태 기반 드래프트로 Speculative Decoding을 강화하다
- 다음글 [vLLM] Sampler: logits에서 토큰까지, 샘플링 파이프라인 전체 분석
댓글