본문으로 건너뛰기

[SGLang] PenaltyLib: 반복/빈도/존재 페널티 구현

들어가며

LLM이 같은 구절을 반복하는 것은 서빙 품질의 치명적 문제다. SGLang의 PenaltyLib는 repetition, frequency, presence 세 가지 페널티와 min_new_tokens 강제를 배치 단위로 효율적으로 적용하는 라이브러리다.

이 글에서는 python/sglang/srt/sampling/penaltylib/ 디렉토리를 중심으로 Orchestrator 패턴과 각 페널티의 구현을 분석한다.

전체 아키텍처

BatchedPenalizerOrchestrator
├── BatchedRepetitionPenalizer   (곱셈, scaling)
├── BatchedFrequencyPenalizer    (뺄셈, additive)
├── BatchedPresencePenalizer     (뺄셈, additive)
└── BatchedMinNewTokensPenalizer (EOS 차단, additive)

apply(logits) 호출 시:
┌────────────────────────────────────────────┐
│ logits (B, V)                              │
│   │                                        │
│   ├─ RepetitionPenalizer._apply(logits)    │
│   │   logits < 0: logits *= penalty        │
│   │   logits > 0: logits /= penalty        │
│   │                                        │
│   ├─ FrequencyPenalizer._apply(logits)     │
│   │   logits -= cumulated_freq_penalties   │
│   │                                        │
│   ├─ PresencePenalizer._apply(logits)      │
│   │   logits -= cumulated_pres_penalties   │
│   │                                        │
│   └─ MinNewTokensPenalizer._apply(logits)  │
│       EOS logits = -inf (if too short)     │
└────────────────────────────────────────────┘

핵심 코드 분석

Orchestrator: 페널티 관리자

BatchedPenalizerOrchestrator는 배치 내 모든 요청의 페널티를 관리한다. weakref를 사용하여 ScheduleBatch 순환 참조를 방지한다.

class BatchedPenalizerOrchestrator:
    def __init__(self, vocab_size, batch, penalizers):
        self.vocab_size = vocab_size
        self._batch_ref = weakref.ref(batch)
        self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}

        is_required = False
        for penalizer in self.penalizers.values():
            pen_is_required = penalizer.prepare_if_required()
            is_required |= pen_is_required
        self.is_required = is_required

각 페널티는 필요한 경우에만 prepare되어 불필요한 텐서 할당을 방지한다.

Repetition Penalty: 스케일링 방식

Repetition penalty는 곱셈 기반이다. 양수 logits는 나누고, 음수 logits는 곱하여 출력된 토큰의 확률을 줄인다.

@torch.compile(dynamic=True, backend=get_compiler_backend())
def apply_scaling_penalties(logits, scaling_penalties):
    logits[:] = torch.where(
        logits < 0,
        logits * scaling_penalties,
        logits / scaling_penalties,
    )

누적 방식은 scatter_로 구현된다. 출력된 토큰 위치에 페널티 값을 기록한다.

def _cumulate_output_tokens(self, output_ids):
    self.cumulated_repetition_penalties.scatter_(
        dim=1,
        index=output_ids.unsqueeze(1),
        src=self.repetition_penalties,
    )

초기값이 1인 (B, V) 텐서에서 출력된 토큰 위치만 penalty 값으로 덮어쓴다. 같은 토큰이 여러 번 나와도 penalty 값은 동일하게 유지된다(빈도 무관).

Frequency Penalty: 빈도 비례 감점

Frequency penalty는 토큰 등장 횟수에 비례하여 logits를 감소시킨다.

class BatchedFrequencyPenalizer(_BatchedPenalizer):
    def _cumulate_output_tokens(self, output_ids):
        self.cumulated_frequency_penalties.scatter_add_(
            dim=1,
            index=output_ids.unsqueeze(1),
            src=self.frequency_penalties,
        )

    def _apply(self, logits):
        logits.sub_(self.cumulated_frequency_penalties)

scatter_add_를 사용하므로 같은 토큰이 N번 등장하면 penalty도 N배가 된다. repetition_penalty와의 차이점이 바로 이것이다.

Presence Penalty: 등장 여부 감점

Presence penalty는 토큰의 등장 여부만 판단한다.

class BatchedPresencePenalizer(_BatchedPenalizer):
    def _cumulate_output_tokens(self, output_ids):
        self.cumulated_presence_penalties.scatter_(
            dim=1,
            index=output_ids.unsqueeze(1),
            src=self.presence_penalties,
        )

    def _apply(self, logits):
        logits.sub_(self.cumulated_presence_penalties)

scatter_(add가 아닌)를 사용하므로 같은 토큰이 여러 번 나와도 penalty가 누적되지 않는다.

Min New Tokens: EOS 차단

최소 토큰 수를 채우기 전에 EOS 토큰을 선택하지 못하게 막는다.

class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
    def _apply(self, logits):
        mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
        logits[mask] += self.stop_token_penalties[mask]

stop_token_penalties는 EOS와 stop 토큰 위치에 -inf를 미리 채워둔 텐서다. 최소 길이에 도달하기 전까지 이 값을 더하여 해당 토큰 선택을 차단한다.

세 가지 페널티 비교

토큰 "hello"가 3번 출력된 상태:

Repetition (penalty=1.2):
  logit=2.0 → 2.0 / 1.2 = 1.67  (3번이든 1번이든 동일)

Frequency (penalty=0.5):
  logit=2.0 → 2.0 - (0.5 × 3) = 0.5  (횟수에 비례)

Presence (penalty=0.5):
  logit=2.0 → 2.0 - 0.5 = 1.5  (등장 여부만)
특성 Repetition Frequency Presence
방식 곱셈/나눗셈 뺄셈 뺄셈
빈도 반영 X O X
초기값 1.0 0.0 0.0
누적 연산 scatter_ scatter_add_ scatter_

Speculative Decoding 지원

apply() 메서드는 speculative decoding의 draft 토큰 레이아웃을 지원한다.

def apply(self, logits, repeat=None):
    if repeat is None:
        for penalizer in self.penalizers.values():
            penalizer.apply(logits)
    else:
        bs = logits.shape[0] // repeat
        additive = torch.zeros((bs, logits.shape[1]), ...)
        self.accumulate_additive_penalties(additive)
        logits.add_(torch.repeat_interleave(additive, repeat, dim=0))

        accumulated = self.accumulate_scaling_penalties()
        if accumulated is not None:
            expanded = torch.repeat_interleave(accumulated, repeat, dim=0)
            apply_scaling_penalties(logits, expanded)

Additive 페널티(frequency, presence)와 scaling 페널티(repetition)를 분리하여 repeat_interleave로 확장한다.

설계 근거

설계 선택 이유
Orchestrator 패턴 페널티 종류가 확장 가능하며, 필요한 것만 prepare하여 메모리 절약
weakref 사용 ScheduleBatch 순환 참조 방지로 GC 누수 차단
is_multiplicative 분류 speculative decoding에서 additive/scaling 페널티 분리 처리 필요
prepare_if_required 배치 내 모든 요청이 penalty=0이면 텐서 할당 자체를 건너뜀

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글