[vLLM] MTP & DFlash: 다중 토큰 예측과 Flash 기반 드래프팅
들어가며
투기적 디코딩(Speculative Decoding)은 작은 드래프트 모델로 여러 토큰을 미리 생성한 뒤, 큰 타겟 모델로 한 번에 검증하는 방식이다. vLLM의 DFlash는 이를 Flash Attention 기반의 병렬 드래프팅으로 구현한다. Multi-Token Prediction(MTP) 논문(arxiv:2404.19737)의 아이디어를 활용하여, 드래프트 모델이 한 번의 forward pass에서 여러 토큰을 동시에 제안한다.
공식 문서
vLLM 공식 문서: MTP Speculative Decoding
핵심 구조/코드 분석
DFlashProposer 초기화
class DFlashProposer(SpecDecodeBaseProposer):
def __init__(self, vllm_config, device, runner=None):
assert vllm_config.speculative_config.method == "dflash"
super().__init__(
vllm_config=vllm_config, device=device,
pass_hidden_states_to_model=True, runner=runner,
)
self.max_query_tokens = self.max_batch_size * (1 + self.num_speculative_tokens)
self.max_positions = self.max_num_tokens + self.max_query_tokens
핵심 포인트는 pass_hidden_states_to_model=True다. DFlash는 타겟 모델의 hidden states를 드래프트 모델의 context로 사용한다. max_query_tokens는 배치 크기 * (1 + 투기 토큰 수)인데, 1은 bonus 토큰(next token)이고 나머지는 mask 토큰이다.
Cross-Attention 기반 병렬 드래프팅
set_inputs_first_pass 메서드가 DFlash의 핵심이다.
def set_inputs_first_pass(self, target_token_ids, next_token_ids,
target_positions, target_hidden_states, token_indices_to_sample,
cad, num_rejected_tokens_gpu):
batch_size = cad.batch_size()
num_context = target_token_ids.shape[0]
num_query_per_req = 1 + self.num_speculative_tokens
num_query_total = batch_size * num_query_per_req
# Triton 커널로 입력 구성
copy_and_expand_dflash_inputs_kernel[grid](
next_token_ids_ptr=next_token_ids,
target_positions_ptr=target_positions,
parallel_drafting_token_id=self.parallel_drafting_token_id,
block_size=self.block_size,
num_query_per_req=num_query_per_req,
...
)
타겟 모델의 hidden states를 context KV로 사용하고, query는 bonus 토큰 + mask 토큰으로 구성한다. Triton 커널(copy_and_expand_dflash_inputs_kernel)로 input_ids, positions, slot_mapping, token_indices를 한 번에 구성한다.
Non-Causal Attention
new_cad = CommonAttentionMetadata(
...
causal=False, # DFlash에서는 non-causal attention 필수
)
DFlash가 일반적인 autoregressive 어텐션과 다른 핵심은 non-causal attention이다. 각 mask 토큰이 모든 context states를 동시에 볼 수 있어야 병렬 드래프팅이 가능하다.
Context KV 사전 삽입
def build_model_inputs_first_pass(self, num_tokens, num_input_tokens, mm_embed_inputs):
num_context = self._dflash_num_context
# 타겟 hidden states -> KV projection -> cache에 직접 삽입
self.model.precompute_and_store_context_kv(
self._dflash_hidden_states,
self._context_positions_buffer[:num_context],
self._context_slot_mapping_buffer[:num_context],
)
타겟 모델의 hidden states를 KV projection(GEMM + norms + RoPE)하여 KV cache에 직접 삽입한다. 이렇게 하면 드래프트 모델의 forward pass에서는 query 토큰만 처리하면 되므로 연산량이 크게 줄어든다.
DraftModelProposer (별도 드래프트 모델)
class DraftModelProposer(SpecDecodeBaseProposer):
def _get_model(self) -> nn.Module:
draft_vllm_config = self._create_draft_vllm_config()
with set_model_tag("draft_model"):
model = get_model(vllm_config=draft_vllm_config, prefix="draft_model")
return model
DFlash와 달리, DraftModelProposer는 완전히 별개의 작은 모델을 드래프트로 사용한다. TP 크기가 타겟과 동일해야 하는 제약이 있다.
왜 이 설계인가
-
병렬 드래프팅의 효율성: 전통적인 autoregressive 드래프팅은 N개의 토큰을 생성하려면 N번의 forward pass가 필요하다. DFlash는 cross-attention 구조로 한 번에 모든 토큰을 생성하여, 드래프팅 지연을
O(N)에서O(1)로 줄인다. -
별도 버퍼 분리: Context와 query의 slot_mapping, positions 버퍼를 분리한 이유는 CUDA 그래프 호환성 때문이다. Query 버퍼 주소가 안정적이어야 CUDA 그래프 캡처가 가능하다.
-
Multimodal 지원: DFlash는
_raise_if_multimodal을 오버라이드하여 멀티모달 입력을 허용한다. Qwen3.5 같은 멀티모달 모델에서도 투기적 디코딩이 가능하도록 설계했다.
참고 자료
관련 포스트
vLLM 의 다른글
- 이전글 [vLLM] Beam Search: 빔 서치 디코딩 구현 분석
- 현재글 : [vLLM] MTP & DFlash: 다중 토큰 예측과 Flash 기반 드래프팅
- 다음글 [vLLM] KV Cache Offloading: GPU에서 CPU로의 KV 캐시 오프로딩
댓글