[SGLang] PenaltyLib: 반복/빈도/존재 페널티 구현
들어가며
LLM이 같은 구절을 반복하는 것은 서빙 품질의 치명적 문제다. SGLang의 PenaltyLib는 repetition, frequency, presence 세 가지 페널티와 min_new_tokens 강제를 배치 단위로 효율적으로 적용하는 라이브러리다.
이 글에서는 python/sglang/srt/sampling/penaltylib/ 디렉토리를 중심으로 Orchestrator 패턴과 각 페널티의 구현을 분석한다.
전체 아키텍처
BatchedPenalizerOrchestrator
├── BatchedRepetitionPenalizer (곱셈, scaling)
├── BatchedFrequencyPenalizer (뺄셈, additive)
├── BatchedPresencePenalizer (뺄셈, additive)
└── BatchedMinNewTokensPenalizer (EOS 차단, additive)
apply(logits) 호출 시:
┌────────────────────────────────────────────┐
│ logits (B, V) │
│ │ │
│ ├─ RepetitionPenalizer._apply(logits) │
│ │ logits < 0: logits *= penalty │
│ │ logits > 0: logits /= penalty │
│ │ │
│ ├─ FrequencyPenalizer._apply(logits) │
│ │ logits -= cumulated_freq_penalties │
│ │ │
│ ├─ PresencePenalizer._apply(logits) │
│ │ logits -= cumulated_pres_penalties │
│ │ │
│ └─ MinNewTokensPenalizer._apply(logits) │
│ EOS logits = -inf (if too short) │
└────────────────────────────────────────────┘
핵심 코드 분석
Orchestrator: 페널티 관리자
BatchedPenalizerOrchestrator는 배치 내 모든 요청의 페널티를 관리한다. weakref를 사용하여 ScheduleBatch 순환 참조를 방지한다.
class BatchedPenalizerOrchestrator:
def __init__(self, vocab_size, batch, penalizers):
self.vocab_size = vocab_size
self._batch_ref = weakref.ref(batch)
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
is_required = False
for penalizer in self.penalizers.values():
pen_is_required = penalizer.prepare_if_required()
is_required |= pen_is_required
self.is_required = is_required
각 페널티는 필요한 경우에만 prepare되어 불필요한 텐서 할당을 방지한다.
Repetition Penalty: 스케일링 방식
Repetition penalty는 곱셈 기반이다. 양수 logits는 나누고, 음수 logits는 곱하여 출력된 토큰의 확률을 줄인다.
@torch.compile(dynamic=True, backend=get_compiler_backend())
def apply_scaling_penalties(logits, scaling_penalties):
logits[:] = torch.where(
logits < 0,
logits * scaling_penalties,
logits / scaling_penalties,
)
누적 방식은 scatter_로 구현된다. 출력된 토큰 위치에 페널티 값을 기록한다.
def _cumulate_output_tokens(self, output_ids):
self.cumulated_repetition_penalties.scatter_(
dim=1,
index=output_ids.unsqueeze(1),
src=self.repetition_penalties,
)
초기값이 1인 (B, V) 텐서에서 출력된 토큰 위치만 penalty 값으로 덮어쓴다. 같은 토큰이 여러 번 나와도 penalty 값은 동일하게 유지된다(빈도 무관).
Frequency Penalty: 빈도 비례 감점
Frequency penalty는 토큰 등장 횟수에 비례하여 logits를 감소시킨다.
class BatchedFrequencyPenalizer(_BatchedPenalizer):
def _cumulate_output_tokens(self, output_ids):
self.cumulated_frequency_penalties.scatter_add_(
dim=1,
index=output_ids.unsqueeze(1),
src=self.frequency_penalties,
)
def _apply(self, logits):
logits.sub_(self.cumulated_frequency_penalties)
scatter_add_를 사용하므로 같은 토큰이 N번 등장하면 penalty도 N배가 된다. repetition_penalty와의 차이점이 바로 이것이다.
Presence Penalty: 등장 여부 감점
Presence penalty는 토큰의 등장 여부만 판단한다.
class BatchedPresencePenalizer(_BatchedPenalizer):
def _cumulate_output_tokens(self, output_ids):
self.cumulated_presence_penalties.scatter_(
dim=1,
index=output_ids.unsqueeze(1),
src=self.presence_penalties,
)
def _apply(self, logits):
logits.sub_(self.cumulated_presence_penalties)
scatter_(add가 아닌)를 사용하므로 같은 토큰이 여러 번 나와도 penalty가 누적되지 않는다.
Min New Tokens: EOS 차단
최소 토큰 수를 채우기 전에 EOS 토큰을 선택하지 못하게 막는다.
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
def _apply(self, logits):
mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
logits[mask] += self.stop_token_penalties[mask]
stop_token_penalties는 EOS와 stop 토큰 위치에 -inf를 미리 채워둔 텐서다. 최소 길이에 도달하기 전까지 이 값을 더하여 해당 토큰 선택을 차단한다.
세 가지 페널티 비교
토큰 "hello"가 3번 출력된 상태:
Repetition (penalty=1.2):
logit=2.0 → 2.0 / 1.2 = 1.67 (3번이든 1번이든 동일)
Frequency (penalty=0.5):
logit=2.0 → 2.0 - (0.5 × 3) = 0.5 (횟수에 비례)
Presence (penalty=0.5):
logit=2.0 → 2.0 - 0.5 = 1.5 (등장 여부만)
| 특성 | Repetition | Frequency | Presence |
|---|---|---|---|
| 방식 | 곱셈/나눗셈 | 뺄셈 | 뺄셈 |
| 빈도 반영 | X | O | X |
| 초기값 | 1.0 | 0.0 | 0.0 |
| 누적 연산 | scatter_ | scatter_add_ | scatter_ |
Speculative Decoding 지원
apply() 메서드는 speculative decoding의 draft 토큰 레이아웃을 지원한다.
def apply(self, logits, repeat=None):
if repeat is None:
for penalizer in self.penalizers.values():
penalizer.apply(logits)
else:
bs = logits.shape[0] // repeat
additive = torch.zeros((bs, logits.shape[1]), ...)
self.accumulate_additive_penalties(additive)
logits.add_(torch.repeat_interleave(additive, repeat, dim=0))
accumulated = self.accumulate_scaling_penalties()
if accumulated is not None:
expanded = torch.repeat_interleave(accumulated, repeat, dim=0)
apply_scaling_penalties(logits, expanded)
Additive 페널티(frequency, presence)와 scaling 페널티(repetition)를 분리하여 repeat_interleave로 확장한다.
설계 근거
| 설계 선택 | 이유 |
|---|---|
| Orchestrator 패턴 | 페널티 종류가 확장 가능하며, 필요한 것만 prepare하여 메모리 절약 |
| weakref 사용 | ScheduleBatch 순환 참조 방지로 GC 누수 차단 |
is_multiplicative 분류 |
speculative decoding에서 additive/scaling 페널티 분리 처리 필요 |
prepare_if_required |
배치 내 모든 요청이 penalty=0이면 텐서 할당 자체를 건너뜀 |
관련 포스트
- Sampler: logits에서 토큰까지 - PenaltyLib가 적용된 logits를 소비하는 Sampler
- Sampling Parameters: 전체 파라미터 정리 - 페널티 파라미터의 유효 범위와 기본값
- Custom Logit Processor: 사용자 정의 로짓 처리 - 페널티 외 사용자 정의 로짓 변환
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Sampling Parameters: 전체 샘플링 파라미터 정리
- 현재글 : [SGLang] PenaltyLib: 반복/빈도/존재 페널티 구현
- 다음글 [SGLang] Custom Logit Processor: 사용자 정의 로짓 처리
댓글