[vLLM] Beam Search: 빔 서치 디코딩 구현 분석
들어가며
LLM 추론에서 디코딩 전략은 출력 품질을 결정하는 핵심 요소다. Greedy decoding은 빠르지만 최적 시퀀스를 놓칠 수 있고, 빔 서치는 여러 후보를 동시에 추적하여 더 나은 결과를 얻는다. vLLM은 vllm/beam_search.py에서 이 빔 서치 로직을 구현하고 있다.
핵심 구조/코드 분석
BeamSearchSequence
빔 서치의 기본 단위는 BeamSearchSequence 데이터클래스다.
@dataclass
class BeamSearchSequence:
orig_prompt: TokensInput | MultiModalInput | EncoderDecoderInput
tokens: list[int]
logprobs: list[dict[int, Logprob]]
lora_request: LoRARequest | None = None
cum_logprob: float = 0.0
text: str | None = None
finish_reason: str | None = None
stop_reason: int | str | None = None
각 빔은 원본 프롬프트(orig_prompt), 현재까지 생성된 토큰(tokens), 누적 로그 확률(cum_logprob)을 추적한다. text 필드는 최종 반환 시에만 채워진다.
프롬프트 재구성: get_prompt()
빔 서치에서 각 빔은 서로 다른 토큰 시퀀스를 가지므로, 매 스텝마다 프롬프트를 재구성해야 한다.
def get_prompt(self):
prompt = self.orig_prompt
if prompt["type"] == "enc_dec":
return self._build_encoder_decoder_inputs(prompt)
if prompt["type"] == "token":
return tokens_input(self.tokens, prompt=prompt_text, cache_salt=cache_salt)
return mm_input(
prompt_token_ids=self.tokens,
mm_kwargs=prompt["mm_kwargs"],
mm_hashes=prompt["mm_hashes"],
mm_placeholders=prompt["mm_placeholders"],
prompt=prompt_text,
cache_salt=cache_salt,
)
token, multimodal, enc_dec 세 가지 타입을 모두 지원한다. Encoder-decoder 모델의 경우 인코더 프롬프트는 유지하고 디코더 프롬프트만 현재 빔의 토큰으로 교체한다.
스코어링 함수와 Length Penalty
빔 정렬에 사용되는 스코어링 함수는 HuggingFace Transformers의 구현을 참고했다.
def get_beam_search_score(
tokens: list[int],
cumulative_logprob: float,
eos_token_id: int,
length_penalty: float = 1.0,
) -> float:
seq_len = len(tokens)
if tokens[-1] == eos_token_id:
seq_len -= 1
return cumulative_logprob / (seq_len**length_penalty)
핵심은 length_penalty 파라미터다. 누적 로그 확률을 시퀀스 길이의 거듭제곱으로 나누어, 짧은 시퀀스에 대한 편향을 조절한다. EOS 토큰은 길이 계산에서 제외된다.
BeamSearchInstance
class BeamSearchInstance:
def __init__(self, prompt, lora_request=None, logprobs=None, **kwargs):
self.beams: list[BeamSearchSequence] = [
BeamSearchSequence(
orig_prompt=prompt,
tokens=initial_tokens,
logprobs=[] if logprobs is None else list(logprobs),
lora_request=lora_request,
**kwargs,
)
]
self.completed: list[BeamSearchSequence] = []
하나의 요청에 대한 빔 서치 상태를 관리한다. beams는 아직 생성 중인 활성 빔, completed는 EOS에 도달한 완료 빔을 담는다.
왜 이 설계인가
-
프롬프트 재구성 방식: 빔마다 독립적인 요청으로 처리하기 때문에 vLLM의 continuous batching과 자연스럽게 통합된다. 각 빔이 별도의 요청처럼 스케줄링되므로, 배치 내 다른 요청과 함께 효율적으로 처리할 수 있다.
-
Encoder-decoder 분리 처리: 인코더 출력은 모든 빔에서 동일하므로, 디코더 프롬프트만 교체하는 것이 올바른 설계다. 다만 코드 주석에서 인코더 멀티모달 캐시가 아직 제대로 연결되지 않아 매 빔마다 인코더를 다시 실행하는 문제가 있음을 언급하고 있다.
-
Length penalty 스코어링: 빔 서치는 길이가 긴 시퀀스일수록 로그 확률이 낮아지는 경향이 있어, length penalty 없이는 짧은 시퀀스를 선호하게 된다.
length_penalty > 1.0이면 긴 시퀀스에 유리하고,< 1.0이면 짧은 시퀀스에 유리하다.
참고 자료
관련 포스트
vLLM 의 다른글
- 이전글 [vLLM] Sampling Parameters: 전체 샘플링 파라미터 정리
- 현재글 : [vLLM] Beam Search: 빔 서치 디코딩 구현 분석
- 다음글 [vLLM] MTP & DFlash: 다중 토큰 예측과 Flash 기반 드래프팅
댓글