[SGLang] Custom Logit Processor: 사용자 정의 로짓 처리
들어가며
표준 샘플링 파라미터로 충분하지 않을 때가 있다. 특정 토큰을 금지하거나, thinking budget을 제어하거나, n-gram 반복을 차단하는 등의 요구에 SGLang은 Custom Logit Processor를 제공한다. 사용자가 Python 클래스를 정의하면 서빙 시점에 logits에 직접 적용된다.
이 글에서는 python/sglang/srt/sampling/custom_logit_processor.py를 중심으로 Custom Logit Processor의 구조와 내장 구현체를 분석한다.
전체 구조도
클라이언트 서버
┌──────────────┐ ┌───────────────────────────┐
│ MyProcessor │──dill────▶│ _cache_from_str() │
│ .to_str() │ 직렬화 │ │ │
└──────────────┘ │ ▼ │
│ CustomLogitProcessor │
│ .__call__(logits, params)│
│ │ │
│ ▼ │
│ Sampler._preprocess_logits │
│ apply_custom_logit_ │
│ processor(logits, info) │
└───────────────────────────┘
핵심 코드 분석
추상 기반 클래스
CustomLogitProcessor는 ABC로 __call__ 메서드를 정의한다.
class CustomLogitProcessor(ABC):
@abstractmethod
def __call__(
self,
logits: torch.Tensor,
custom_param_list: Optional[List[Dict[str, Any]]] = None,
) -> torch.Tensor:
raise NotImplementedError
logits는 (batch_subset, vocab_size) 형상이며, custom_param_list는 해당 배치 요소들의 사용자 정의 파라미터 리스트다. 반환된 텐서가 원본 logits를 대체한다.
직렬화와 역직렬화
클라이언트에서 서버로 프로세서를 전달하기 위해 dill을 사용한다.
@classmethod
def to_str(cls) -> str:
return json.dumps({"callable": dill.dumps(cls).hex()})
@classmethod
def from_str(cls, json_str: str):
return _cache_from_str(json_str)()
@lru_cache(maxsize=None)
def _cache_from_str(json_str: str):
data = orjson.loads(json_str)
return dill.loads(bytes.fromhex(data["callable"]))
to_str()로 클래스를 직렬화하고, from_str()으로 역직렬화한다. lru_cache로 동일 프로세서의 반복 역직렬화를 방지한다.
배치 적용 로직
Sampler에서 호출되는 apply_custom_logit_processor 함수는 배치 마스크 기반으로 동작한다.
def apply_custom_logit_processor(logits, sampling_batch_info, num_tokens_in_batch=1):
for _, (processor, batch_mask) in sampling_batch_info.custom_logit_processor.items():
batch_indices = batch_mask.nonzero(as_tuple=True)[0]
batch_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
logits[batch_mask] = processor(
logits[batch_mask],
[sampling_batch_info.custom_params[i] for i in batch_indices],
)
같은 프로세서를 사용하는 요청들을 마스크로 그룹화하여 한 번에 처리한다. num_tokens_in_batch는 speculative decoding에서 요청당 복수 토큰을 처리할 때 사용된다.
내장 구현체 분석
DisallowedTokensLogitsProcessor
특정 토큰을 완전히 차단한다.
class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
def __call__(self, logits, custom_param_list=None):
disallowed_token_ids = custom_param_list[0]["token_ids"]
logits[..., disallowed_token_ids] = -float("inf")
return logits
차단할 토큰의 logit을 -inf로 설정하여 softmax 후 확률을 0으로 만든다.
ThinkingBudgetLogitProcessor
Thinking 모델(Qwen3, DeepSeek-R1, GLM-4)의 사고 길이를 제어한다.
class ThinkingBudgetLogitProcessor(CustomLogitProcessor):
THINKING_START_TOKEN_ID: int
THINKING_END_TOKEN_ID: int
NEW_LINE_TOKEN_ID: int
def __call__(self, logits, custom_param_list):
for i, param_dict in enumerate(custom_param_list):
thinking_budget = param_dict.get("thinking_budget")
req = param_dict.get("__req__")
cur_ids = [*req.origin_input_ids, *req.output_ids]
if self.THINKING_START_TOKEN_ID not in cur_ids \
or self.THINKING_END_TOKEN_ID in cur_ids:
continue
start_index = cur_ids.index(self.THINKING_START_TOKEN_ID)
num_tokens_after_start = len(cur_ids) - start_index - 1
if num_tokens_after_start >= thinking_budget:
if not req.output_ids or req.output_ids[-1] != self.NEW_LINE_TOKEN_ID:
logits[i, :] = -float("inf")
logits[i, self.NEW_LINE_TOKEN_ID] = 0.0
else:
logits[i, :] = -float("inf")
logits[i, self.THINKING_END_TOKEN_ID] = 0.0
return logits
사고 토큰 수가 budget을 초과하면 먼저 줄바꿈 토큰을 강제하고, 그 다음 thinking 종료 토큰을 강제한다. 모델별 서브클래스가 토큰 ID를 정의한다.
class Qwen3ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
THINKING_START_TOKEN_ID = 151667
THINKING_END_TOKEN_ID = 151668
NEW_LINE_TOKEN_ID = 198
class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
THINKING_START_TOKEN_ID = 128798
THINKING_END_TOKEN_ID = 128799
NEW_LINE_TOKEN_ID = 201
DeepseekOCRNoRepeatNGramLogitProcessor
OCR 출력에서 n-gram 반복을 슬라이딩 윈도우 내에서 차단한다.
class DeepseekOCRNoRepeatNGramLogitProcessor(CustomLogitProcessor):
def __call__(self, logits, custom_param_list=None):
for batch_idx, params in enumerate(custom_param_list):
ngram_size = int(params.get("ngram_size") or 0)
window_size = int(params.get("window_size") or 0)
sequence = req.origin_input_ids + req.output_ids
if ngram_size > 1:
current_prefix = tuple(sequence[-(ngram_size - 1):])
banned_tokens = set()
for idx in range(search_start, search_end):
ngram = sequence[idx : idx + ngram_size]
if tuple(ngram[:-1]) == current_prefix:
banned_tokens.add(ngram[-1])
banned_tokens.difference_update(whitelist)
logits[batch_idx, list(banned_tokens)] = -float("inf")
return logits
현재 접두사와 일치하는 n-gram의 마지막 토큰을 차단하되, whitelist에 있는 토큰은 허용한다.
보안 고려사항
dill을 통한 임의 코드 역직렬화는 보안 위험이 있다. SGLang은 서버 시작 시 --enable-custom-logit-processor 플래그로 이 기능을 명시적으로 활성화해야 한다. 프로덕션 환경에서는 신뢰할 수 있는 클라이언트만 사용하도록 제한해야 한다.
설계 근거
| 설계 선택 | 이유 |
|---|---|
| ABC + dill 직렬화 | 클라이언트에서 임의 Python 로직을 서버에 전달 가능 |
| lru_cache 역직렬화 | 동일 프로세서의 반복 역직렬화 비용 제거 |
| batch_mask 기반 적용 | 같은 프로세서를 사용하는 요청을 그룹화하여 단일 호출 |
__req__ 파라미터 |
프로세서가 현재까지의 출력 토큰을 참조할 수 있음 |
| 모델별 서브클래스 | 토큰 ID만 다르고 로직은 동일한 패턴을 코드 재사용 |
관련 포스트
- Sampler: logits에서 토큰까지 - Custom Logit Processor가 적용되는 Sampler 파이프라인
- Sampling Parameters: 전체 파라미터 정리 - custom_params 파라미터의 정의
- PenaltyLib: 반복/빈도/존재 페널티 - 내장 페널티와 Custom Logit Processor의 차이
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] PenaltyLib: 반복/빈도/존재 페널티 구현
- 현재글 : [SGLang] Custom Logit Processor: 사용자 정의 로짓 처리
- 다음글 [SGLang] Multimodal 처리 파이프라인 개요: Vision/Audio/Video 통합
댓글