본문으로 건너뛰기

[SGLang] Custom Logit Processor: 사용자 정의 로짓 처리

들어가며

표준 샘플링 파라미터로 충분하지 않을 때가 있다. 특정 토큰을 금지하거나, thinking budget을 제어하거나, n-gram 반복을 차단하는 등의 요구에 SGLang은 Custom Logit Processor를 제공한다. 사용자가 Python 클래스를 정의하면 서빙 시점에 logits에 직접 적용된다.

이 글에서는 python/sglang/srt/sampling/custom_logit_processor.py를 중심으로 Custom Logit Processor의 구조와 내장 구현체를 분석한다.

전체 구조도

클라이언트                         서버
┌──────────────┐           ┌───────────────────────────┐
│ MyProcessor  │──dill────▶│ _cache_from_str()         │
│   .to_str()  │  직렬화   │   │                       │
└──────────────┘           │   ▼                       │
                           │ CustomLogitProcessor      │
                           │   .__call__(logits, params)│
                           │   │                       │
                           │   ▼                       │
                           │ Sampler._preprocess_logits │
                           │   apply_custom_logit_     │
                           │   processor(logits, info) │
                           └───────────────────────────┘

핵심 코드 분석

추상 기반 클래스

CustomLogitProcessor는 ABC로 __call__ 메서드를 정의한다.

class CustomLogitProcessor(ABC):
    @abstractmethod
    def __call__(
        self,
        logits: torch.Tensor,
        custom_param_list: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

logits는 (batch_subset, vocab_size) 형상이며, custom_param_list는 해당 배치 요소들의 사용자 정의 파라미터 리스트다. 반환된 텐서가 원본 logits를 대체한다.

직렬화와 역직렬화

클라이언트에서 서버로 프로세서를 전달하기 위해 dill을 사용한다.

@classmethod
def to_str(cls) -> str:
    return json.dumps({"callable": dill.dumps(cls).hex()})

@classmethod
def from_str(cls, json_str: str):
    return _cache_from_str(json_str)()

@lru_cache(maxsize=None)
def _cache_from_str(json_str: str):
    data = orjson.loads(json_str)
    return dill.loads(bytes.fromhex(data["callable"]))

to_str()로 클래스를 직렬화하고, from_str()으로 역직렬화한다. lru_cache로 동일 프로세서의 반복 역직렬화를 방지한다.

배치 적용 로직

Sampler에서 호출되는 apply_custom_logit_processor 함수는 배치 마스크 기반으로 동작한다.

def apply_custom_logit_processor(logits, sampling_batch_info, num_tokens_in_batch=1):
    for _, (processor, batch_mask) in sampling_batch_info.custom_logit_processor.items():
        batch_indices = batch_mask.nonzero(as_tuple=True)[0]
        batch_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
        logits[batch_mask] = processor(
            logits[batch_mask],
            [sampling_batch_info.custom_params[i] for i in batch_indices],
        )

같은 프로세서를 사용하는 요청들을 마스크로 그룹화하여 한 번에 처리한다. num_tokens_in_batch는 speculative decoding에서 요청당 복수 토큰을 처리할 때 사용된다.

내장 구현체 분석

DisallowedTokensLogitsProcessor

특정 토큰을 완전히 차단한다.

class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
    def __call__(self, logits, custom_param_list=None):
        disallowed_token_ids = custom_param_list[0]["token_ids"]
        logits[..., disallowed_token_ids] = -float("inf")
        return logits

차단할 토큰의 logit을 -inf로 설정하여 softmax 후 확률을 0으로 만든다.

ThinkingBudgetLogitProcessor

Thinking 모델(Qwen3, DeepSeek-R1, GLM-4)의 사고 길이를 제어한다.

class ThinkingBudgetLogitProcessor(CustomLogitProcessor):
    THINKING_START_TOKEN_ID: int
    THINKING_END_TOKEN_ID: int
    NEW_LINE_TOKEN_ID: int

    def __call__(self, logits, custom_param_list):
        for i, param_dict in enumerate(custom_param_list):
            thinking_budget = param_dict.get("thinking_budget")
            req = param_dict.get("__req__")
            cur_ids = [*req.origin_input_ids, *req.output_ids]

            if self.THINKING_START_TOKEN_ID not in cur_ids \
               or self.THINKING_END_TOKEN_ID in cur_ids:
                continue

            start_index = cur_ids.index(self.THINKING_START_TOKEN_ID)
            num_tokens_after_start = len(cur_ids) - start_index - 1

            if num_tokens_after_start >= thinking_budget:
                if not req.output_ids or req.output_ids[-1] != self.NEW_LINE_TOKEN_ID:
                    logits[i, :] = -float("inf")
                    logits[i, self.NEW_LINE_TOKEN_ID] = 0.0
                else:
                    logits[i, :] = -float("inf")
                    logits[i, self.THINKING_END_TOKEN_ID] = 0.0
        return logits

사고 토큰 수가 budget을 초과하면 먼저 줄바꿈 토큰을 강제하고, 그 다음 thinking 종료 토큰을 강제한다. 모델별 서브클래스가 토큰 ID를 정의한다.

class Qwen3ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    THINKING_START_TOKEN_ID = 151667
    THINKING_END_TOKEN_ID = 151668
    NEW_LINE_TOKEN_ID = 198

class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    THINKING_START_TOKEN_ID = 128798
    THINKING_END_TOKEN_ID = 128799
    NEW_LINE_TOKEN_ID = 201

DeepseekOCRNoRepeatNGramLogitProcessor

OCR 출력에서 n-gram 반복을 슬라이딩 윈도우 내에서 차단한다.

class DeepseekOCRNoRepeatNGramLogitProcessor(CustomLogitProcessor):
    def __call__(self, logits, custom_param_list=None):
        for batch_idx, params in enumerate(custom_param_list):
            ngram_size = int(params.get("ngram_size") or 0)
            window_size = int(params.get("window_size") or 0)
            sequence = req.origin_input_ids + req.output_ids

            if ngram_size > 1:
                current_prefix = tuple(sequence[-(ngram_size - 1):])

            banned_tokens = set()
            for idx in range(search_start, search_end):
                ngram = sequence[idx : idx + ngram_size]
                if tuple(ngram[:-1]) == current_prefix:
                    banned_tokens.add(ngram[-1])

            banned_tokens.difference_update(whitelist)
            logits[batch_idx, list(banned_tokens)] = -float("inf")
        return logits

현재 접두사와 일치하는 n-gram의 마지막 토큰을 차단하되, whitelist에 있는 토큰은 허용한다.

보안 고려사항

dill을 통한 임의 코드 역직렬화는 보안 위험이 있다. SGLang은 서버 시작 시 --enable-custom-logit-processor 플래그로 이 기능을 명시적으로 활성화해야 한다. 프로덕션 환경에서는 신뢰할 수 있는 클라이언트만 사용하도록 제한해야 한다.

설계 근거

설계 선택 이유
ABC + dill 직렬화 클라이언트에서 임의 Python 로직을 서버에 전달 가능
lru_cache 역직렬화 동일 프로세서의 반복 역직렬화 비용 제거
batch_mask 기반 적용 같은 프로세서를 사용하는 요청을 그룹화하여 단일 호출
__req__ 파라미터 프로세서가 현재까지의 출력 토큰을 참조할 수 있음
모델별 서브클래스 토큰 ID만 다르고 로직은 동일한 패턴을 코드 재사용

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글