본문으로 건너뛰기

[vLLM] Medusa: 다중 예측 헤드 투기적 디코딩

들어가며

Medusa는 투기적 디코딩(Speculative Decoding)의 한 방법으로, 별도의 드래프트 모델 없이 타겟 모델 자체에 여러 개의 예측 헤드를 추가하여 다음 여러 토큰을 동시에 예측한다. 일반적인 투기적 디코딩이 별도의 작은 모델을 필요로 하는 것과 달리, Medusa는 타겟 모델의 hidden state를 재사용하므로 메모리 효율적이다.

소스 경로: vllm/v1/spec_decode/medusa.py

논문: Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads

공식 문서

vLLM 공식 문서: Speculative Decoding

핵심 구조/코드 분석

MedusaProposer 클래스

class MedusaProposer:
    """Medusa proposer class for generating token sequences"""

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        self.vllm_config = vllm_config
        self.spec_config = vllm_config.speculative_config
        self.device = device
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.hidden_size = self.spec_config.draft_model_config.get_hidden_size()
        self.dtype = vllm_config.model_config.dtype

MedusaProposerspec_config에서 드래프트 모델 설정을 가져온다. 여기서 draft_model_config는 Medusa 헤드의 설정을 담고 있으며, hidden_size는 타겟 모델의 hidden state 차원과 동일하다.

propose() - 드래프트 토큰 생성

def propose(
    self,
    target_hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    slot_mappings=None,
) -> torch.Tensor:
    # Medusa 헤드를 통해 블록 생성 및 로짓 계산
    blocks = self.model(target_hidden_states)
    logits = self.model.compute_logits(blocks)

    # 각 Medusa 헤드의 argmax를 스택하여 드래프트 토큰 생성
    # Shape: [batch_size, num_heads]
    draft_tokens = torch.stack(
        [logit.argmax(dim=-1) for logit in logits], dim=1
    )
    return draft_tokens

핵심 로직은 놀랍도록 간단하다:

  1. 타겟 모델의 hidden_states를 Medusa 모델에 입력
  2. 각 Medusa 헤드가 다음 위치의 로짓을 독립적으로 예측
  3. 각 헤드에서 argmax로 가장 확률 높은 토큰을 선택
  4. 모든 헤드의 예측을 스택하여 [batch_size, num_heads] 형태로 반환

모델 로딩

def load_model(self, target_model: nn.Module) -> None:
    from vllm.compilation.backends import set_model_tag

    with set_model_tag("medusa_head"):
        self.model = get_model(
            vllm_config=self.vllm_config,
            model_config=self.spec_config.draft_model_config,
        )
    assert not (
        is_mixture_of_experts(self.model)
        and self.vllm_config.parallel_config.enable_eplb
    ), "EPLB for Medusa is not supported"

set_model_tag("medusa_head")로 컴파일 시 Medusa 헤드를 별도로 식별한다. 이를 통해 torch.compile이 메인 모델과 Medusa 헤드를 독립적으로 최적화할 수 있다. MoE 모델의 EPLB(Expert Parallelism Load Balancing)는 아직 지원되지 않는다.

더미 실행

@torch.inference_mode()
def dummy_run(self, num_tokens: int) -> None:
    hidden_states = torch.zeros(
        (self.max_num_tokens, self.hidden_size),
        dtype=self.dtype, device=self.device,
    )
    with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
        self.model(hidden_states)

메모리 프로파일링과 CUDA 그래프 캡처를 위해 더미 실행을 수행한다. set_forward_context를 사용하여 실제 추론 시와 동일한 환경에서 실행된다.

왜 이 설계인가

  1. Hidden State 재사용: Medusa는 타겟 모델의 마지막 레이어 hidden state를 입력으로 사용한다. 별도의 드래프트 모델 포워드 패스가 필요 없으므로 오버헤드가 매우 작다. 논문에 따르면 2-3배의 속도 향상을 달성할 수 있다.

  2. Greedy Drafting: argmax를 사용하여 각 헤드에서 가장 확률 높은 토큰만 선택한다. Tree attention 기반의 복잡한 검증 과정 없이도 높은 수락률을 달성한다.

  3. GPUModelRunner 통합: MedusaProposer는 GPUModelRunner에서 init_speculator()를 통해 초기화되며, 메인 모델 포워드 패스 직후에 호출된다. RejectionSampler가 드래프트 토큰을 검증하여 수락/거절을 결정한다.

  4. 별도 컴파일 태그: "medusa_head" 태그를 사용하여 torch.compile 캐시를 메인 모델과 분리한다. 이는 Medusa 헤드만 업데이트해도 메인 모델의 컴파일 캐시를 무효화하지 않도록 한다.

논문 핵심 내용

Medusa 논문은 별도의 드래프트 모델 없이 타겟 모델 자체에 경량 예측 헤드를 추가하는 방식으로 speculative decoding을 구현했다. Medusa-1은 품질 저하 없이 2.2배 이상의 속도 향상을, Medusa-2는 특수한 학습 방법을 적용하여 2.3-3.6배의 속도 향상을 달성했다.

모델별 벤치마크

모델 Medusa-2 속도비 평균 수락 토큰 MT-Bench 품질
Vicuna-7B 2.83x 3.47 6.18
Vicuna-13B 2.83x 3.51 6.43
Vicuna-33B 2.35x 3.01 7.18
Zephyr-7B 2.66x 3.14 7.25

기존 Speculative Decoding 대비 성능

모델 Medusa-2 Speculative Decoding (별도 모델)
Vicuna-7B 2.83x 1.47x
Vicuna-13B 2.83x 1.56x
Vicuna-33B 2.35x 1.60x

Medusa가 별도 드래프트 모델 기반 방식보다 거의 2배 가까이 빠르다는 것은 인상적이다.

Tree Attention 기법별 누적 효과

기법 속도비
Medusa 헤드 단독 (tree attention 없음) ~1.5x
+ Tree attention 추가 ~1.9x
+ 최적화된 트리 구성 ~2.2x
+ Medusa-2 학습 ~2.8x

AlpacaEval 데이터셋에서는 Vicuna-13B가 3.16배 속도 향상을 달성하며 가장 높은 성능을 보였다. 카테고리별로는 코딩(3.29x)과 추출(3.62x) 태스크에서 특히 높은 속도 향상이 관찰되었는데, 이는 이런 태스크들이 상대적으로 예측 가능한 패턴을 가지고 있기 때문이다.

참고

댓글

관련 포스트

vLLM 의 다른글