본문으로 건너뛰기

[vLLM] Beam Search: 빔 서치 디코딩 구현 분석

들어가며

LLM 추론에서 디코딩 전략은 출력 품질을 결정하는 핵심 요소다. Greedy decoding은 빠르지만 최적 시퀀스를 놓칠 수 있고, 빔 서치는 여러 후보를 동시에 추적하여 더 나은 결과를 얻는다. vLLM은 vllm/beam_search.py에서 이 빔 서치 로직을 구현하고 있다.

핵심 구조/코드 분석

BeamSearchSequence

빔 서치의 기본 단위는 BeamSearchSequence 데이터클래스다.

@dataclass
class BeamSearchSequence:
    orig_prompt: TokensInput | MultiModalInput | EncoderDecoderInput
    tokens: list[int]
    logprobs: list[dict[int, Logprob]]
    lora_request: LoRARequest | None = None
    cum_logprob: float = 0.0
    text: str | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None

각 빔은 원본 프롬프트(orig_prompt), 현재까지 생성된 토큰(tokens), 누적 로그 확률(cum_logprob)을 추적한다. text 필드는 최종 반환 시에만 채워진다.

프롬프트 재구성: get_prompt()

빔 서치에서 각 빔은 서로 다른 토큰 시퀀스를 가지므로, 매 스텝마다 프롬프트를 재구성해야 한다.

def get_prompt(self):
    prompt = self.orig_prompt
    if prompt["type"] == "enc_dec":
        return self._build_encoder_decoder_inputs(prompt)
    if prompt["type"] == "token":
        return tokens_input(self.tokens, prompt=prompt_text, cache_salt=cache_salt)
    return mm_input(
        prompt_token_ids=self.tokens,
        mm_kwargs=prompt["mm_kwargs"],
        mm_hashes=prompt["mm_hashes"],
        mm_placeholders=prompt["mm_placeholders"],
        prompt=prompt_text,
        cache_salt=cache_salt,
    )

token, multimodal, enc_dec 세 가지 타입을 모두 지원한다. Encoder-decoder 모델의 경우 인코더 프롬프트는 유지하고 디코더 프롬프트만 현재 빔의 토큰으로 교체한다.

스코어링 함수와 Length Penalty

빔 정렬에 사용되는 스코어링 함수는 HuggingFace Transformers의 구현을 참고했다.

def get_beam_search_score(
    tokens: list[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float:
    seq_len = len(tokens)
    if tokens[-1] == eos_token_id:
        seq_len -= 1
    return cumulative_logprob / (seq_len**length_penalty)

핵심은 length_penalty 파라미터다. 누적 로그 확률을 시퀀스 길이의 거듭제곱으로 나누어, 짧은 시퀀스에 대한 편향을 조절한다. EOS 토큰은 길이 계산에서 제외된다.

BeamSearchInstance

class BeamSearchInstance:
    def __init__(self, prompt, lora_request=None, logprobs=None, **kwargs):
        self.beams: list[BeamSearchSequence] = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=initial_tokens,
                logprobs=[] if logprobs is None else list(logprobs),
                lora_request=lora_request,
                **kwargs,
            )
        ]
        self.completed: list[BeamSearchSequence] = []

하나의 요청에 대한 빔 서치 상태를 관리한다. beams는 아직 생성 중인 활성 빔, completed는 EOS에 도달한 완료 빔을 담는다.

왜 이 설계인가

  1. 프롬프트 재구성 방식: 빔마다 독립적인 요청으로 처리하기 때문에 vLLM의 continuous batching과 자연스럽게 통합된다. 각 빔이 별도의 요청처럼 스케줄링되므로, 배치 내 다른 요청과 함께 효율적으로 처리할 수 있다.

  2. Encoder-decoder 분리 처리: 인코더 출력은 모든 빔에서 동일하므로, 디코더 프롬프트만 교체하는 것이 올바른 설계다. 다만 코드 주석에서 인코더 멀티모달 캐시가 아직 제대로 연결되지 않아 매 빔마다 인코더를 다시 실행하는 문제가 있음을 언급하고 있다.

  3. Length penalty 스코어링: 빔 서치는 길이가 긴 시퀀스일수록 로그 확률이 낮아지는 경향이 있어, length penalty 없이는 짧은 시퀀스를 선호하게 된다. length_penalty > 1.0이면 긴 시퀀스에 유리하고, < 1.0이면 짧은 시퀀스에 유리하다.

참고 자료

댓글

관련 포스트

vLLM 의 다른글