[vLLM] Sampler: logits에서 토큰까지, 샘플링 파이프라인 전체 분석
들어가며
LLM의 출력은 어휘 크기(vocabulary size)만큼의 logits 벡터이다. 이것을 실제 토큰으로 변환하는 과정에서 temperature, top-k, top-p, repetition penalty, logprobs 등 다양한 처리가 필요하다. vLLM의 Sampler는 이 모든 과정을 GPU에서 배치로 처리하는 파이프라인이다.
핵심 구조/코드 분석
Sampler 클래스의 9단계 파이프라인
vllm/v1/sample/sampler.py의 docstring이 전체 파이프라인을 정의한다:
class Sampler(nn.Module):
"""
1. logprobs가 요청되면: raw logprobs 또는 raw logits 복제
2. logits를 float32로 변환
3. allowed token ids 화이트리스트 적용
4. bad words 제외
5. argmax-invariant가 아닌 logit processors 적용
a) Min tokens processor
b) Logit bias processor
6. 페널티 적용
a) Repetition penalty
b) Frequency penalty
c) Presence penalty
7. 샘플링: greedy 또는 random
a) temperature 적용
b) argmax-invariant logit processors (min_p)
c) top_k / top_p 적용
d) 확률 분포에서 샘플링
8. top logprobs 수집
9. SamplerOutput 반환
"""
forward 메서드: 파이프라인 실행
def forward(self, logits, sampling_metadata, ...):
# 1. raw logprobs 저장 (나중에 반환용)
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
if logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(logits)
elif logprobs_mode == "raw_logits":
raw_logprobs = logits.to(torch.float32)
# 2. float32로 변환
logits = logits.to(torch.float32)
# 3-6. logit processors + penalties 적용
logits = self.apply_logits_processors(
logits, sampling_metadata, predict_bonus_token
)
# 7. 샘플링
sampled, processed_logprobs = self.sample(logits, sampling_metadata)
sampled = sampled.long() # int64로 변환
# 8-9. logprobs 수집 및 반환
sampled = sampled.to(torch.int32)
return SamplerOutput(
sampled_token_ids=sampled.unsqueeze(-1),
logprobs_tensors=logprobs_tensors,
)
Greedy vs Random 샘플링 분기
def sample(self, logits, sampling_metadata):
assert not (sampling_metadata.all_greedy
and sampling_metadata.all_random)
if sampling_metadata.all_random:
greedy_sampled = None
else:
greedy_sampled = self.greedy_sample(logits)
if sampling_metadata.all_greedy:
return greedy_sampled, processed_logprobs
# Temperature 적용
logits = self.apply_temperature(
logits, sampling_metadata.temperature,
sampling_metadata.all_random
)
# Argmax-invariant processors (min_p 등)
for processor in sampling_metadata.logitsprocs.argmax_invariant:
logits = processor.apply(logits)
# Top-k/Top-p 샘플링
random_sampled, processed_logprobs = self.topk_topp_sampler(
logits, sampling_metadata.generators,
sampling_metadata.top_k, sampling_metadata.top_p,
)
# Greedy와 Random 결과 합성
if greedy_sampled is None:
return random_sampled, processed_logprobs
sampled = torch.where(
sampling_metadata.temperature < _SAMPLING_EPS,
greedy_sampled,
random_sampled,
out=greedy_sampled,
)
return sampled, processed_logprobs
배치 내에서 요청마다 다른 temperature를 가질 수 있다. temperature가 epsilon(1e-5) 미만이면 greedy 결과를, 아니면 random 결과를 사용한다. torch.where로 배치 전체를 한 번에 처리한다.
Temperature 적용
@staticmethod
def apply_temperature(logits, temp, all_random):
if not all_random:
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
return logits.div_(temp.unsqueeze(dim=1))
in-place division(div_)을 사용하여 새 텐서 할당을 방지한다. greedy 요청(temp ≈ 0)은 1.0으로 치환해서 0으로 나누는 것을 방지한다.
페널티 적용
@staticmethod
def apply_penalties(logits, sampling_metadata, output_token_ids):
if sampling_metadata.no_penalties:
return logits
return apply_all_penalties(
logits,
sampling_metadata.prompt_token_ids,
sampling_metadata.presence_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.repetition_penalties,
output_token_ids,
)
페널티가 필요 없는 경우 완전히 건너뛴다(no_penalties 플래그). 이것은 대부분의 서빙 시나리오에서 불필요한 연산을 제거하는 최적화이다.
Logprobs 수집
@staticmethod
def gather_logprobs(logprobs, num_logprobs, token_ids):
# Top-k logprobs 추출
topk_logprobs, topk_indices = torch.topk(
logprobs, num_logprobs, dim=-1
)
# 실제 샘플된 토큰의 logprob
token_logprobs = logprobs.gather(-1, token_ids.unsqueeze(-1))
# 실제 토큰의 순위 계산
token_ranks = batched_count_greater_than(logprobs, token_logprobs)
# 결합
indices = torch.cat((token_ids.unsqueeze(-1), topk_indices), dim=1)
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
return LogprobsTensors(indices.to(torch.int32), logprobs, token_ranks)
batched_count_greater_than으로 샘플된 토큰이 전체 어휘에서 몇 번째로 확률이 높은지(rank)를 계산한다. 이 정보는 모델의 "확신도"를 파악하는 데 유용하다.
왜 이 설계인가
-
배치 효율: 모든 요청을 하나의 텐서로 묶어서 처리한다. 요청마다 다른 sampling 파라미터를 가져도
torch.where와 마스킹으로 분기 없이 처리한다. -
조기 종료 최적화:
all_greedy이면 temperature/top-k/top-p를 건너뛰고,no_penalties이면 페널티를 건너뛴다. 대부분의 실제 워크로드에서 이 최적화가 적용된다. -
메모리 효율: in-place 연산(
div_,masked_fill_)으로 임시 텐서 할당을 최소화한다. 큰 배치에서 어휘 크기(128K+)를 곱하면 절약량이 상당하다. -
Logprobs와 샘플링의 분리: raw logprobs를 먼저 저장하고 이후에 sampling-specific 변환을 적용한다. OpenAI API의
logprobs스펙과의 호환성을 유지하면서도 효율적인 구현을 달성한다.
Sampler는 vLLM에서 유일하게 "모든 요청이 반드시 거치는" 컴포넌트이다. 따라서 매 토큰마다 호출되는 이 모듈의 효율성이 전체 서빙 성능에 직접적인 영향을 미친다.
관련 포스트
vLLM 의 다른글
- 이전글 [vLLM] EAGLE: 은닉 상태 기반 드래프트로 Speculative Decoding을 강화하다
- 현재글 : [vLLM] Sampler: logits에서 토큰까지, 샘플링 파이프라인 전체 분석
- 다음글 [vLLM] Tensor Parallelism: 거대 모델을 여러 GPU에 나누는 텐서 병렬화
댓글