본문으로 건너뛰기

[SGLang] Reasoner Grammar: 추론 체인 제약 생성

들어가며

DeepSeek-R1, QwQ 같은 추론(reasoning) 모델은 <think>...</think> 블록 안에서 Chain-of-Thought를 생성한 뒤, 최종 답변을 출력한다. 이때 문제가 발생한다: JSON Schema 제약을 걸면 추론 단계의 자유 텍스트도 제약을 받아, 모델이 "생각"할 수 없게 된다.

SGLang의 ReasonerGrammarBackend는 이 문제를 해결한다. think_end_id 토큰(예: </think>)을 기준으로 추론 단계와 응답 단계를 분리하여, 추론 중에는 제약을 비활성화하고 응답 단계에서만 문법 제약을 적용한다.

소스 파일: python/sglang/srt/constrained/reasoner_grammar_backend.py

구조도

┌──────────────────────────────────────────────────┐
            ReasonerGrammarBackend                 
  (Decorator Pattern: 기존 백엔드를 감싸는 래퍼)    
                                                  
  grammar_backend: BaseGrammarBackend             
  think_end_id: int                               
                                                  
  _init_value_dispatch()                          
    └─► grammar_backend._init_value_dispatch()    
          └─► ReasonerGrammarObject(ret, ...)     
└──────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────┐
            ReasonerGrammarObject                  
                                                  
  grammar: BaseGrammarObject  (실제 제약 객체)     
  think_end_id: int                               
  tokens_after_think_end: int                     
                                                  
  상태 전이:                                       
  ┌──────────┐  think_end_id  ┌──────────────┐    
   추론 단계  ────────────►  응답 단계         
   state=-1                 state=0,1,2..│    
   제약 OFF                 제약 ON          
  └──────────┘               └──────────────┘    
└──────────────────────────────────────────────────┘

핵심 코드 분석

1. Decorator Pattern 적용

class ReasonerGrammarBackend(BaseGrammarBackend):
    def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
        super().__init__()
        self.grammar_backend = grammar_backend
        self.think_end_id = think_end_id

ReasonerGrammarBackend는 기존 백엔드(XGrammar, Outlines, LLGuidance)를 감싸는 래퍼다. base_grammar_backend.pycreate_grammar_backend()에서 조건부로 적용된다:

# base_grammar_backend.py
if server_args.reasoning_parser and think_end_id is not None:
    from sglang.srt.constrained.reasoner_grammar_backend import ReasonerGrammarBackend
    grammar_backend = ReasonerGrammarBackend(grammar_backend, think_end_id)

reasoning_parser가 설정되어 있고 모델에 think_end_id가 존재할 때만 래핑한다. 기존 백엔드의 모든 dispatch 메서드를 그대로 위임하되, 반환되는 Grammar 객체를 ReasonerGrammarObject로 감싼다.

2. 값 디스패치와 Invalid Grammar 처리

def _init_value_dispatch(self, key, reasoning):
    ret = self.grammar_backend._init_value_dispatch(key, reasoning)
    if ret is None or isinstance(ret, InvalidGrammarObject):
        return ret  # 래핑하지 않음
    obj = ReasonerGrammarObject(ret, self.think_end_id)
    obj.maybe_init_reasoning(reasoning)
    return obj

유효하지 않은 문법(InvalidGrammarObject)은 래핑하지 않고 그대로 반환한다. 스케줄러가 isinstance(req.grammar, InvalidGrammarObject)로 실패를 감지해야 하기 때문이다. 유효한 문법만 ReasonerGrammarObject로 감싸고, maybe_init_reasoning()으로 추론 모드를 초기화한다.

3. 상태 머신: tokens_after_think_end

class ReasonerGrammarObject(BaseGrammarObject):
    def __init__(self, grammar, think_end_id):
        super().__init__()
        self.grammar = grammar
        self.think_end_id = think_end_id
        self.tokens_after_think_end = -1
        # -1: 추론 단계 (thinking 진행 중)
        #  0: think_end_id를 방금 받음
        # +N: 응답 단계 (think 종료 후 N개 토큰)

상태는 정수 하나로 관리된다:

  • -1: 추론 단계. <think> 블록 내부. 문법 제약이 비활성화된다.
  • 0: think_end_id 토큰을 방금 수락한 시점. 이 다음 토큰부터 제약이 적용된다.
  • 양수 N: 응답 단계. think_end_id 이후 N개 토큰이 생성됨. 문법 제약이 활성화된다.
def maybe_init_reasoning(self, reasoning: bool):
    self.tokens_after_think_end = -1 if reasoning else 0

reasoning=True이면 추론 단계(-1)에서 시작하고, False이면 바로 응답 단계(0)에서 시작한다. 이는 모델이 추론을 하지 않는 일반 요청에서도 ReasonerGrammarObject가 올바르게 동작하도록 한다.

4. 상태 전이와 토큰 수락

def transfer_state(self, token: int) -> int:
    if self.tokens_after_think_end == -1 and token == self.think_end_id:
        self.tokens_after_think_end = 0
    elif self.tokens_after_think_end >= 0:
        self.tokens_after_think_end += 1

def accept_token(self, token: int):
    if self.tokens_after_think_end >= 0:
        self.grammar.accept_token(token)  # 응답 단계에서만 문법에 전달
    self.transfer_state(token)

accept_token()의 실행 순서가 중요하다. 먼저 현재 상태를 확인하여 응답 단계(>= 0)이면 내부 문법에 토큰을 전달한다. 그 후에 상태를 전이한다. 이렇게 하면 think_end_id 토큰 자체는 문법에 전달되지 않는다(수락 시점에는 아직 -1).

5. 마스크 조건부 적용

def fill_vocab_mask(self, vocab_mask, idx):
    if self.tokens_after_think_end >= 0:
        self.grammar.fill_vocab_mask(vocab_mask, idx)

추론 단계(tokens_after_think_end == -1)에서는 fill_vocab_mask()가 호출되지 않으므로, 마스크가 모두 허용 상태로 유지된다. 모델은 자유롭게 "생각"할 수 있다. 응답 단계에서만 내부 문법의 마스크가 적용되어 JSON, Regex 등의 제약이 걸린다.

6. Rollback 처리

def rollback_state(self):
    if self.tokens_after_think_end == 0:
        self.tokens_after_think_end = -1  # think_end_id 이전으로 복귀
    elif self.tokens_after_think_end > 0:
        self.tokens_after_think_end -= 1

def rollback(self, k):
    steps_after_think = min(k, self.tokens_after_think_end)
    if steps_after_think > 0:
        self.grammar.rollback(steps_after_think)

    for _ in range(k):
        self.rollback_state()

Rollback은 두 부분으로 나뉜다. 먼저 응답 단계에서 생성된 토큰 수(steps_after_think)만큼만 내부 문법을 rollback한다. 추론 단계의 토큰은 문법에 전달된 적이 없으므로 rollback할 필요가 없다. 그 후 상태 카운터를 k번 감소시킨다.

예: tokens_after_think_end = 3이고 k = 5이면, 문법은 3번만 rollback하고, 상태 카운터는 5번 감소하여 -1(추론 단계)로 돌아간다.

7. 위임 메서드들

def allocate_vocab_mask(self, vocab_size, batch_size, device):
    return self.grammar.allocate_vocab_mask(vocab_size, batch_size, device)

@property
def apply_vocab_mask(self):
    return self.grammar.apply_vocab_mask

def copy(self):
    return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)

def try_jump_forward(self, tokenizer):
    return self.grammar.try_jump_forward(tokenizer)

마스크 할당, 적용, 복사, Jump-Forward 등은 내부 문법 객체에 그대로 위임한다. apply_vocab_mask@property로 구현된 이유는, 이 메서드가 @staticmethod인 백엔드(XGrammar, Outlines)와의 호환성을 유지하면서 동적으로 내부 문법의 구현을 반환하기 위함이다.

설계 근거

왜 Decorator Pattern인가? 추론 모델 지원을 위해 기존 3개 백엔드를 각각 수정하면 코드 중복이 발생한다. 래퍼 패턴으로 추론/응답 분리 로직을 한 곳에 집중시키고, 어떤 백엔드든 동일하게 감쌀 수 있다.

왜 카운터 방식인가? 불리언 플래그 대신 정수 카운터를 사용하면 rollback 시 정확히 몇 토큰을 문법에서 되돌려야 하는지 계산할 수 있다. 이는 speculative decoding과의 호환성에 필수적이다.

accept 전에 전달, 전이는 후에: think_end_id 토큰은 문법 구조의 일부가 아니다(JSON에 </think>가 포함되면 안 된다). 따라서 상태 전이를 토큰 전달 이후에 수행하여, think_end_id가 문법에 들어가는 것을 자연스럽게 방지한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글