본문으로 건너뛰기

[SGLang] XGrammar: JSON/Regex 제약 백엔드

들어가며

XGrammar는 SGLang의 기본(default) Constrained Decoding 백엔드다. JSON Schema, Regex, EBNF, Structural Tag를 모두 지원하며, 비트마스크 기반의 GPU 친화적인 토큰 필터링을 제공한다. xgrammar 라이브러리를 기반으로 구현되어 있으며, 파일 경로는 python/sglang/srt/constrained/xgrammar_backend.py다.

구조도

┌──────────────────────────────────────────────┐
│           XGrammarGrammarBackend             │
│  ┌──────────────┐  ┌─────────────────────┐   │
│  │GrammarCompiler│  │ override_stop_tokens│   │
│  │(TokenizerInfo)│  │ any_whitespace      │   │
│  └──────┬───────┘  └─────────────────────┘   │
│         │                                    │
│  dispatch_json() ──► compile_json_schema()   │
│  dispatch_regex() ─► compile_regex()         │
│  dispatch_ebnf() ──► compile_grammar()       │
│  dispatch_structural_tag() ─► compile_structural_tag()│
│         │                                    │
│         ▼                                    │
│  ┌─────────────────────┐                     │
│  │  CompiledGrammar    │  (컴파일된 문법)     │
│  └─────────┬───────────┘                     │
│            ▼                                 │
│  ┌─────────────────────┐                     │
│  │  GrammarMatcher     │  (런타임 상태 머신)  │
│  └─────────┬───────────┘                     │
│            ▼                                 │
│  ┌─────────────────────┐                     │
│  │ XGrammarGrammar     │  (SGLang 인터페이스) │
│  └─────────────────────┘                     │
└──────────────────────────────────────────────┘

핵심 코드 분석

1. 백엔드 초기화와 TokenizerInfo

class XGrammarGrammarBackend(BaseGrammarBackend):
    def __init__(self, tokenizer, vocab_size, model_eos_token_ids=None, any_whitespace=True):
        super().__init__()
        if hasattr(tokenizer, "init_xgrammar"):
            tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
        else:
            tokenizer_info = TokenizerInfo.from_huggingface(
                tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
            )
            override_stop_tokens = None

        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
        self.vocab_size = vocab_size

XGrammar는 토크나이저의 전체 어휘를 분석하여 TokenizerInfo를 생성한다. 이 정보는 각 토큰이 어떤 바이트 시퀀스에 매핑되는지를 담고 있어, 문법 규칙과 토큰을 정확히 대응시킬 수 있다. 특수 토크나이저(예: ChatGLM)는 init_xgrammar() 메서드를 통해 별도 처리한다.

2. JSON Schema 디스패치

def dispatch_json(self, key_string: str) -> BaseGrammarObject:
    try:
        if key_string == "$$ANY$$":
            ctx = self.grammar_compiler.compile_builtin_json_grammar()
        else:
            ctx = self.grammar_compiler.compile_json_schema(
                schema=key_string, any_whitespace=self.any_whitespace
            )
    except (RuntimeError, json.decoder.JSONDecodeError, UnicodeDecodeError) as e:
        return InvalidGrammarObject(str(e))
    return self._from_context(ctx, key_string, GrammarStats(dispatch_type="json"))

"$$ANY$$" 키워드는 모든 유효한 JSON을 허용하는 내장 문법을 사용한다. 일반 JSON Schema의 경우 compile_json_schema()가 스키마를 파싱하여 PDA(Pushdown Automaton) 형태의 CompiledGrammar를 생성한다. any_whitespace=True이면 JSON 키 사이의 공백을 유연하게 허용한다.

3. GrammarMatcher 생성과 토큰 수락

def _from_context(self, ctx: CompiledGrammar, key_string, grammar_stats):
    matcher = GrammarMatcher(
        ctx,
        max_rollback_tokens=MAX_ROLLBACK_TOKENS,  # 200
        override_stop_tokens=self.override_stop_tokens,
    )
    return XGrammarGrammar(matcher, self.vocab_size, ctx, ...)

GrammarMatcherCompiledGrammar 위에서 동작하는 런타임 상태 머신이다. max_rollback_tokens=200은 speculative decoding에서 잘못된 토큰을 되돌릴 수 있는 최대 범위를 지정한다.

class XGrammarGrammar(BaseGrammarObject):
    def accept_token(self, token: int):
        if not self.is_terminated():
            accepted = self.matcher.accept_token(token)
            if not accepted:
                raise ValueError(
                    f"Tokens not accepted: {token}\n"
                    f"Accepted tokens: {self.accepted_tokens}\n"
                    f"Key string: {self.key_string}"
                )
            else:
                self.accepted_tokens.append(token)

매 토큰 생성 후 accept_token()으로 문법 상태를 전이한다. 문법에 위배되는 토큰이 들어오면 즉시 예외를 발생시킨다.

4. 비트마스크 기반 토큰 필터링

def allocate_vocab_mask(self, vocab_size, batch_size, device):
    return allocate_token_bitmask(batch_size, vocab_size)

def fill_vocab_mask(self, vocab_mask, idx):
    self.matcher.fill_next_token_bitmask(vocab_mask, idx)

def apply_vocab_mask(self, logits, vocab_mask):
    if logits.device.type in {"cuda", "npu", "xpu", "musa"}:
        if _is_hip:
            apply_token_bitmask_inplace_cuda(logits, vocab_mask)
        else:
            apply_token_bitmask_inplace_triton(logits, vocab_mask)

이것이 Constrained Decoding의 핵심 경로다. 3단계로 구성된다:

  1. allocate: 배치 크기 x 어휘 크기의 비트마스크 텐서를 할당한다
  2. fill: 현재 문법 상태에서 허용되는 토큰을 비트마스크에 기록한다
  3. apply: 비트마스크를 logits에 적용하여 금지된 토큰의 확률을 -inf로 설정한다

비트마스크는 Bool 텐서 대비 32배 메모리 효율적이다. AMD GPU(HIP)에서는 sgl_kernel의 CUDA 구현을, NVIDIA GPU에서는 Triton 구현을 사용한다.

5. Rollback과 Jump-Forward

def rollback(self, k: int):
    self.matcher.rollback(k)
    self.accepted_tokens = self.accepted_tokens[:-k]

def try_jump_forward(self, tokenizer):
    s = self.matcher.find_jump_forward_string()
    if s:
        return [], s
    return None

rollback()은 speculative decoding에서 거부된 토큰을 되돌릴 때 사용한다. try_jump_forward()는 현재 문법 상태에서 다음에 올 수 있는 문자열이 하나뿐일 때, 해당 문자열을 미리 건너뛰어 디코딩 단계를 절약한다.

6. Structural Tag 처리

def dispatch_structural_tag(self, key_string: str):
    structural_tag = json.loads(key_string)
    if is_legacy_structural_tag(structural_tag):
        self._sanitize_structural_tag_structures(structural_tag)
        tags = [
            StructuralTagItem(
                begin=structure["begin"],
                schema=json.dumps(structure["schema"]),
                end=structure["end"],
            )
            for structure in structural_tag["structures"]
        ]
        ctx = self.grammar_compiler.compile_structural_tag(
            tags, structural_tag["triggers"]
        )
    else:
        format_dict = structural_tag.get("format")
        if isinstance(format_dict, dict):
            self._sanitize_structural_format(format_dict)
        ctx = self.grammar_compiler.compile_structural_tag(key_string)

Structural Tag는 function calling에서 사용된다. <tool_call> 같은 begin/end 태그 사이에 JSON Schema 제약을 적용한다. 레거시 포맷과 새 포맷을 모두 지원하며, _sanitize_structural_format()이 누락된 json_schema 필드를 빈 스키마로 채운다.

설계 근거

왜 비트마스크인가? 어휘 크기가 128K인 모델에서 Bool 텐서는 배치당 128KB를 소비한다. 비트마스크는 이를 4KB로 줄인다. GPU 메모리 대역폭이 핵심 병목인 LLM 추론에서 이 차이는 의미 있다.

왜 CompiledGrammar과 GrammarMatcher를 분리하는가? 동일한 JSON Schema로 여러 요청이 들어올 때 CompiledGrammar는 한 번만 생성하고, 요청별로 독립적인 GrammarMatcher를 만들어 상태를 관리한다. 컴파일 비용은 높지만 매칭 비용은 낮다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글