본문으로 건너뛰기

[vLLM] MTP & DFlash: 다중 토큰 예측과 Flash 기반 드래프팅

들어가며

투기적 디코딩(Speculative Decoding)은 작은 드래프트 모델로 여러 토큰을 미리 생성한 뒤, 큰 타겟 모델로 한 번에 검증하는 방식이다. vLLM의 DFlash는 이를 Flash Attention 기반의 병렬 드래프팅으로 구현한다. Multi-Token Prediction(MTP) 논문(arxiv:2404.19737)의 아이디어를 활용하여, 드래프트 모델이 한 번의 forward pass에서 여러 토큰을 동시에 제안한다.

공식 문서

vLLM 공식 문서: MTP Speculative Decoding

핵심 구조/코드 분석

DFlashProposer 초기화

class DFlashProposer(SpecDecodeBaseProposer):
    def __init__(self, vllm_config, device, runner=None):
        assert vllm_config.speculative_config.method == "dflash"
        super().__init__(
            vllm_config=vllm_config, device=device,
            pass_hidden_states_to_model=True, runner=runner,
        )
        self.max_query_tokens = self.max_batch_size * (1 + self.num_speculative_tokens)
        self.max_positions = self.max_num_tokens + self.max_query_tokens

핵심 포인트는 pass_hidden_states_to_model=True다. DFlash는 타겟 모델의 hidden states를 드래프트 모델의 context로 사용한다. max_query_tokens배치 크기 * (1 + 투기 토큰 수)인데, 1은 bonus 토큰(next token)이고 나머지는 mask 토큰이다.

Cross-Attention 기반 병렬 드래프팅

set_inputs_first_pass 메서드가 DFlash의 핵심이다.

def set_inputs_first_pass(self, target_token_ids, next_token_ids,
    target_positions, target_hidden_states, token_indices_to_sample,
    cad, num_rejected_tokens_gpu):
    batch_size = cad.batch_size()
    num_context = target_token_ids.shape[0]
    num_query_per_req = 1 + self.num_speculative_tokens
    num_query_total = batch_size * num_query_per_req

    # Triton 커널로 입력 구성
    copy_and_expand_dflash_inputs_kernel[grid](
        next_token_ids_ptr=next_token_ids,
        target_positions_ptr=target_positions,
        parallel_drafting_token_id=self.parallel_drafting_token_id,
        block_size=self.block_size,
        num_query_per_req=num_query_per_req,
        ...
    )

타겟 모델의 hidden states를 context KV로 사용하고, query는 bonus 토큰 + mask 토큰으로 구성한다. Triton 커널(copy_and_expand_dflash_inputs_kernel)로 input_ids, positions, slot_mapping, token_indices를 한 번에 구성한다.

Non-Causal Attention

new_cad = CommonAttentionMetadata(
    ...
    causal=False,  # DFlash에서는 non-causal attention 필수
)

DFlash가 일반적인 autoregressive 어텐션과 다른 핵심은 non-causal attention이다. 각 mask 토큰이 모든 context states를 동시에 볼 수 있어야 병렬 드래프팅이 가능하다.

Context KV 사전 삽입

def build_model_inputs_first_pass(self, num_tokens, num_input_tokens, mm_embed_inputs):
    num_context = self._dflash_num_context
    # 타겟 hidden states -> KV projection -> cache에 직접 삽입
    self.model.precompute_and_store_context_kv(
        self._dflash_hidden_states,
        self._context_positions_buffer[:num_context],
        self._context_slot_mapping_buffer[:num_context],
    )

타겟 모델의 hidden states를 KV projection(GEMM + norms + RoPE)하여 KV cache에 직접 삽입한다. 이렇게 하면 드래프트 모델의 forward pass에서는 query 토큰만 처리하면 되므로 연산량이 크게 줄어든다.

DraftModelProposer (별도 드래프트 모델)

class DraftModelProposer(SpecDecodeBaseProposer):
    def _get_model(self) -> nn.Module:
        draft_vllm_config = self._create_draft_vllm_config()
        with set_model_tag("draft_model"):
            model = get_model(vllm_config=draft_vllm_config, prefix="draft_model")
        return model

DFlash와 달리, DraftModelProposer는 완전히 별개의 작은 모델을 드래프트로 사용한다. TP 크기가 타겟과 동일해야 하는 제약이 있다.

왜 이 설계인가

  1. 병렬 드래프팅의 효율성: 전통적인 autoregressive 드래프팅은 N개의 토큰을 생성하려면 N번의 forward pass가 필요하다. DFlash는 cross-attention 구조로 한 번에 모든 토큰을 생성하여, 드래프팅 지연을 O(N)에서 O(1)로 줄인다.

  2. 별도 버퍼 분리: Context와 query의 slot_mapping, positions 버퍼를 분리한 이유는 CUDA 그래프 호환성 때문이다. Query 버퍼 주소가 안정적이어야 CUDA 그래프 캡처가 가능하다.

  3. Multimodal 지원: DFlash는 _raise_if_multimodal을 오버라이드하여 멀티모달 입력을 허용한다. Qwen3.5 같은 멀티모달 모델에서도 투기적 디코딩이 가능하도록 설계했다.

참고 자료

댓글

관련 포스트

vLLM 의 다른글