본문으로 건너뛰기

[vLLM] Sampler: logits에서 토큰까지, 샘플링 파이프라인 전체 분석

들어가며

LLM의 출력은 어휘 크기(vocabulary size)만큼의 logits 벡터이다. 이것을 실제 토큰으로 변환하는 과정에서 temperature, top-k, top-p, repetition penalty, logprobs 등 다양한 처리가 필요하다. vLLM의 Sampler는 이 모든 과정을 GPU에서 배치로 처리하는 파이프라인이다.

핵심 구조/코드 분석

Sampler 클래스의 9단계 파이프라인

vllm/v1/sample/sampler.py의 docstring이 전체 파이프라인을 정의한다:

class Sampler(nn.Module):
    """
    1. logprobs가 요청되면: raw logprobs 또는 raw logits 복제
    2. logits를 float32로 변환
    3. allowed token ids 화이트리스트 적용
    4. bad words 제외
    5. argmax-invariant가 아닌 logit processors 적용
       a) Min tokens processor
       b) Logit bias processor
    6. 페널티 적용
       a) Repetition penalty
       b) Frequency penalty
       c) Presence penalty
    7. 샘플링: greedy 또는 random
       a) temperature 적용
       b) argmax-invariant logit processors (min_p)
       c) top_k / top_p 적용
       d) 확률 분포에서 샘플링
    8. top logprobs 수집
    9. SamplerOutput 반환
    """

forward 메서드: 파이프라인 실행

def forward(self, logits, sampling_metadata, ...):
    # 1. raw logprobs 저장 (나중에 반환용)
    num_logprobs = sampling_metadata.max_num_logprobs
    if num_logprobs is not None:
        if logprobs_mode == "raw_logprobs":
            raw_logprobs = self.compute_logprobs(logits)
        elif logprobs_mode == "raw_logits":
            raw_logprobs = logits.to(torch.float32)

    # 2. float32로 변환
    logits = logits.to(torch.float32)

    # 3-6. logit processors + penalties 적용
    logits = self.apply_logits_processors(
        logits, sampling_metadata, predict_bonus_token
    )

    # 7. 샘플링
    sampled, processed_logprobs = self.sample(logits, sampling_metadata)
    sampled = sampled.long()  # int64로 변환

    # 8-9. logprobs 수집 및 반환
    sampled = sampled.to(torch.int32)
    return SamplerOutput(
        sampled_token_ids=sampled.unsqueeze(-1),
        logprobs_tensors=logprobs_tensors,
    )

Greedy vs Random 샘플링 분기

def sample(self, logits, sampling_metadata):
    assert not (sampling_metadata.all_greedy
                and sampling_metadata.all_random)

    if sampling_metadata.all_random:
        greedy_sampled = None
    else:
        greedy_sampled = self.greedy_sample(logits)
        if sampling_metadata.all_greedy:
            return greedy_sampled, processed_logprobs

    # Temperature 적용
    logits = self.apply_temperature(
        logits, sampling_metadata.temperature,
        sampling_metadata.all_random
    )

    # Argmax-invariant processors (min_p 등)
    for processor in sampling_metadata.logitsprocs.argmax_invariant:
        logits = processor.apply(logits)

    # Top-k/Top-p 샘플링
    random_sampled, processed_logprobs = self.topk_topp_sampler(
        logits, sampling_metadata.generators,
        sampling_metadata.top_k, sampling_metadata.top_p,
    )

    # Greedy와 Random 결과 합성
    if greedy_sampled is None:
        return random_sampled, processed_logprobs

    sampled = torch.where(
        sampling_metadata.temperature < _SAMPLING_EPS,
        greedy_sampled,
        random_sampled,
        out=greedy_sampled,
    )
    return sampled, processed_logprobs

배치 내에서 요청마다 다른 temperature를 가질 수 있다. temperature가 epsilon(1e-5) 미만이면 greedy 결과를, 아니면 random 결과를 사용한다. torch.where로 배치 전체를 한 번에 처리한다.

Temperature 적용

@staticmethod
def apply_temperature(logits, temp, all_random):
    if not all_random:
        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
    return logits.div_(temp.unsqueeze(dim=1))

in-place division(div_)을 사용하여 새 텐서 할당을 방지한다. greedy 요청(temp ≈ 0)은 1.0으로 치환해서 0으로 나누는 것을 방지한다.

페널티 적용

@staticmethod
def apply_penalties(logits, sampling_metadata, output_token_ids):
    if sampling_metadata.no_penalties:
        return logits
    return apply_all_penalties(
        logits,
        sampling_metadata.prompt_token_ids,
        sampling_metadata.presence_penalties,
        sampling_metadata.frequency_penalties,
        sampling_metadata.repetition_penalties,
        output_token_ids,
    )

페널티가 필요 없는 경우 완전히 건너뛴다(no_penalties 플래그). 이것은 대부분의 서빙 시나리오에서 불필요한 연산을 제거하는 최적화이다.

Logprobs 수집

@staticmethod
def gather_logprobs(logprobs, num_logprobs, token_ids):
    # Top-k logprobs 추출
    topk_logprobs, topk_indices = torch.topk(
        logprobs, num_logprobs, dim=-1
    )
    # 실제 샘플된 토큰의 logprob
    token_logprobs = logprobs.gather(-1, token_ids.unsqueeze(-1))
    # 실제 토큰의 순위 계산
    token_ranks = batched_count_greater_than(logprobs, token_logprobs)
    # 결합
    indices = torch.cat((token_ids.unsqueeze(-1), topk_indices), dim=1)
    logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
    return LogprobsTensors(indices.to(torch.int32), logprobs, token_ranks)

batched_count_greater_than으로 샘플된 토큰이 전체 어휘에서 몇 번째로 확률이 높은지(rank)를 계산한다. 이 정보는 모델의 "확신도"를 파악하는 데 유용하다.

왜 이 설계인가

  1. 배치 효율: 모든 요청을 하나의 텐서로 묶어서 처리한다. 요청마다 다른 sampling 파라미터를 가져도 torch.where와 마스킹으로 분기 없이 처리한다.

  2. 조기 종료 최적화: all_greedy이면 temperature/top-k/top-p를 건너뛰고, no_penalties이면 페널티를 건너뛴다. 대부분의 실제 워크로드에서 이 최적화가 적용된다.

  3. 메모리 효율: in-place 연산(div_, masked_fill_)으로 임시 텐서 할당을 최소화한다. 큰 배치에서 어휘 크기(128K+)를 곱하면 절약량이 상당하다.

  4. Logprobs와 샘플링의 분리: raw logprobs를 먼저 저장하고 이후에 sampling-specific 변환을 적용한다. OpenAI API의 logprobs 스펙과의 호환성을 유지하면서도 효율적인 구현을 달성한다.

Sampler는 vLLM에서 유일하게 "모든 요청이 반드시 거치는" 컴포넌트이다. 따라서 매 토큰마다 호출되는 이 모듈의 효율성이 전체 서빙 성능에 직접적인 영향을 미친다.

댓글

관련 포스트

vLLM 의 다른글