본문으로 건너뛰기

[SGLang] 중간 표현(IR): SglGen, SglSelect, SglExpr의 설계

들어가며

SGLang은 LLM 프로그래밍을 위한 DSL(Domain-Specific Language)이다. 사용자가 작성한 SGL 함수는 직접 실행되지 않고, 먼저 중간 표현(Intermediate Representation, IR)으로 변환된다. 이 IR은 생성(gen), 선택(select), 분기(fork) 같은 연산을 트리 구조의 노드로 표현하며, interpreter나 tracer가 이를 순회하며 실제 백엔드 호출로 변환한다.

컴파일러에서 IR이 소스 코드와 기계어 사이의 추상화 계층인 것처럼, SGLang의 IR은 사용자의 프롬프트 로직과 백엔드 실행 사이의 추상화 계층이다. 이 글에서는 python/sglang/lang/ir.py의 코드를 분석하여 IR 시스템의 설계를 살펴본다.

IR 변환 구조도

SGL 함수가 실행되기까지의 전체 흐름을 도식화하면 다음과 같다.

SGL Code (Python DSL)           IR Tree                    Execution
========================    ==================    ========================

@sgl.function                SglExprList            interpreter.run_program()
def chat(s, question):       +-- SglRoleBegin         |
  s += s.system("...")           ("system")           +-> Backend.generate()
  s += s.user(question)      +-- SglConstantText      |   (OpenAI / SRT / ...)
  s += s.assistant(              ("You are...")        |
    sgl.gen("answer",        +-- SglRoleEnd            +-> Backend.select()
      max_new_tokens=256)        ("system")            |
  )                          +-- SglRoleBegin          +-> Result collection
                                 ("user")                  s["answer"]
                             +-- SglArgument
                                 (question)
                             +-- SglRoleEnd
                                 ("user")
                             +-- SglRoleBegin
                                 ("assistant")
                             +-- SglGen
                                 ("answer", max=256)
                             +-- SglRoleEnd
                                 ("assistant")

SGL 코드의 += 연산자는 SglExpr.__add__를 호출하여 IR 노드를 순서대로 연결한다. 최종적으로 SglExprList가 만들어지고, interpreter가 이 리스트를 순회하며 백엔드 API를 호출한다.

핵심 코드 분석

SglExpr: 모든 IR 노드의 기반 클래스

SglExpr는 모든 IR 노드의 부모 클래스다. 각 노드에 고유 ID를 부여하고, + 연산자 오버로딩으로 노드를 연결하는 핵심 메커니즘을 제공한다.

class SglExpr:
    node_ct = 0

    def __init__(self):
        self.node_id = SglExpr.node_ct
        self.prev_node = None
        self.pid = None
        SglExpr.node_ct += 1

    def __add__(self, other):
        if isinstance(other, str):
            other = SglConstantText(other)
        assert isinstance(other, SglExpr)
        return self.concatenate_ir(self, other)

    def __radd__(self, other):
        if isinstance(other, str):
            other = SglConstantText(other)
        assert isinstance(other, SglExpr), f"{other}"
        return self.concatenate_ir(other, self)

node_ct는 클래스 변수로, 생성되는 모든 노드에 전역적으로 증가하는 고유 ID를 부여한다. __add____radd__가 모두 구현되어 있어 "문자열" + SglExpr 형태도 자연스럽게 처리된다. 문자열은 자동으로 SglConstantText 노드로 래핑된다.

SglExprList: IR 노드의 연결

concatenate_ir 메서드가 노드를 연결하는 방식이 흥미롭다. 두 노드를 무조건 새 리스트로 감싸는 것이 아니라, 이미 SglExprList인 경우 기존 리스트를 확장한다.

def concatenate_ir(self, a, b):
    if isinstance(a, SglExprList):
        if isinstance(b, SglExprList):
            return SglExprList(a.expr_list + b.expr_list)
        else:
            return SglExprList(a.expr_list + [b])
    elif isinstance(b, SglExprList):
        return SglExprList([a] + b.expr_list)
    return SglExprList([a, b])

이 설계 덕분에 s += a; s += b; s += c처럼 여러 노드를 연속으로 추가해도 중첩된 트리가 아닌 평탄한 리스트가 만들어진다. 컴파일러 이론에서 말하는 linearized IR에 해당한다.

SglGen: 텍스트 생성 노드

SglGen은 LLM에게 텍스트 생성을 요청하는 IR 노드다. sampling parameter 전체를 SglSamplingParams 객체로 캡슐화한다.

class SglGen(SglExpr):
    def __init__(
        self,
        name: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        temperature: Optional[float] = None,
        regex: Optional[str] = None,
        json_schema: Optional[str] = None,
        # ... 기타 sampling 파라미터
    ):
        super().__init__()
        self.name = name
        self.sampling_params = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            regex=regex,
            json_schema=json_schema,
            # ...
        )

name은 생성 결과를 저장할 변수명이다. s += sgl.gen("answer")로 호출하면 나중에 s["answer"]로 결과를 꺼낼 수 있다. regexjson_schema 파라미터는 constrained generation을 지원하여, IR 수준에서 출력 형식 제약을 선언적으로 명시할 수 있게 한다.

SglSelect: 선택 노드

SglSelect는 여러 후보 중 하나를 선택하는 IR 노드다. classification이나 routing 같은 패턴에 사용된다.

class SglSelect(SglExpr):
    def __init__(
        self,
        name: str,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ):
        super().__init__()
        self.name = name
        self.choices = choices
        self.temperature = temperature
        self.choices_method = choices_method

choices_methodChoicesSamplingMethod 타입으로, 후보 선택 전략(예: token probability 기반)을 지정한다. SglGen과 달리 SglSamplingParams를 사용하지 않고 자체 파라미터만 보유하는데, select 연산은 새 토큰을 생성하는 것이 아니라 기존 후보의 확률을 비교하는 것이기 때문이다.

SglImage / SglVideo: 멀티모달 노드

멀티모달 입력도 IR 노드로 표현된다.

class SglImage(SglExpr):
    def __init__(self, path: str):
        self.path = path

class SglVideo(SglExpr):
    def __init__(self, path: str, num_frames: int):
        self.path = path
        self.num_frames = num_frames

SglImageSglVideo는 파일 경로만 보유하는 경량 노드다. 실제 이미지/비디오 로딩은 interpreter가 이 노드를 만났을 때 수행한다. IR은 "무엇을 할지"만 기술하고 "어떻게 할지"는 실행 단계에 위임하는 원칙이 여기서도 일관되게 적용된다.

SglFunction: IR에서 실행 계획으로

SglFunction은 사용자가 @sgl.function으로 데코레이트한 함수를 감싸는 래퍼다. IR을 생성하고 실행까지 연결하는 진입점이다.

class SglFunction:
    def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
        self.func = func
        self.bind_arguments = bind_arguments or {}
        argspec = inspect.getfullargspec(func)
        assert argspec.args[0] == "s", 'The first argument must be "s"'
        self.arg_names = argspec.args[1:]

    def run(self, *args, max_new_tokens=128, temperature=1.0, backend=None, **kwargs):
        from sglang.lang.interpreter import run_program
        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            # ...
        )
        backend = backend or global_config.default_backend
        return run_program(self, backend, args, kwargs, default_sampling_para, ...)

run 메서드는 기본 sampling parameter를 SglSamplingParams로 구성한 뒤, interpreter.run_program에 위임한다. 여기서 IR이 실제 백엔드 호출로 변환된다. __call__ 메서드는 tracing scope 유무에 따라 run(즉시 실행) 또는 trace(지연 실행)로 분기한다.

def __call__(self, *args, **kwargs):
    from sglang.lang.tracer import TracingScope
    tracing_scope = TracingScope.get_current_scope()
    if tracing_scope is None:
        return self.run(*args, **kwargs)
    else:
        kwargs["backend"] = tracing_scope.tracer_state.backend
        return self.trace(*args, **kwargs)

이 이중 실행 경로는 같은 SGL 함수를 개발 시에는 eager 모드로, 프로덕션에서는 traced/compiled 모드로 사용할 수 있게 한다.

SglExpr.print_graph_dfs는 IR 그래프를 DFS로 순회하며 텍스트로 출력한다. 디버깅 시 IR 구조를 확인하는 데 유용하다.

def print_graph_dfs(self):
    ret = [""]
    visited = set()
    def dfs_print(x):
        if x is None or x in visited:
            return
        visited.add(x)
        if x.prev_node is not None:
            dfs_print(x.prev_node)
        if isinstance(x, SglExprList):
            for y in x.expr_list:
                dfs_print(y)
        if isinstance(x, (SglFork, SglGetForkItem)):
            ret[0] += f"%{x.node_id} = {x}\n"
        else:
            if x.prev_node is not None:
                ret[0] += f"%{x.node_id} = %{x.prev_node.node_id} + " + str(x) + "\n"
            else:
                ret[0] += f"%{x.node_id} = " + str(x) + "\n"
    dfs_print(self)
    return ret[0]

출력 형식이 %0 = Constant('Hello'), %1 = %0 + Gen('answer') 같은 SSA(Static Single Assignment) 스타일이다. 컴파일러의 IR dump와 동일한 패턴을 따른다.

컴파일러 이론과의 비교

SGLang IR의 설계를 전통적인 컴파일러 IR과 비교하면 다음과 같다.

컴파일러 IR 개념 SGLang IR 대응
AST Node SglExpr 및 하위 클래스
Basic Block SglExprList (선형 노드 시퀀스)
SSA Value node_id 기반 고유 식별
Instruction SglGen, SglSelect, SglConstantText
Operand SglSamplingParams, choices
Function SglFunction
Lowering (IR → Machine Code) interpreter.run_program (IR → Backend API Call)
Trace-based JIT TracingScope 기반 지연 실행

특히 concatenate_ir가 중첩 트리가 아닌 평탄한 리스트를 만드는 설계는, 컴파일러에서 tree IR을 linearize하여 basic block으로 만드는 과정과 같다. 이 덕분에 interpreter는 단순한 for 루프로 IR을 순회할 수 있다.

SglForkSglGetForkItem은 병렬 실행 분기를 표현하는 노드로, 컴파일러의 control flow graph에서 branch/merge에 해당한다. prev_node를 통한 의존성 추적은 SSA의 def-use chain과 유사한 역할을 한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글