본문으로 건너뛰기

[SGLang] IO 데이터 구조: 요청에서 응답까지의 직렬화 설계

들어가며

SGLang은 TokenizerManager, Scheduler, DetokenizerManager라는 세 개의 프로세스가 ZMQ IPC로 통신하는 파이프라인 아키텍처를 사용한다. 이 프로세스들 사이를 오가는 데이터는 모두 Python dataclass로 정의되며, pickle을 통해 직렬화/역직렬화된다. 이 데이터 구조들이 바로 io_struct.py에 정의되어 있다.

이 파일은 SGLang 파이프라인의 "프로토콜"이다. 어떤 데이터가 어떤 형태로 전달되는지, 어떤 필드가 선택적이고 어떤 필드가 필수인지, 단일 요청과 배치 요청이 어떻게 구분되는지를 모두 담고 있다. 이 글에서는 python/sglang/srt/managers/io_struct.py를 중심으로 핵심 데이터 구조를 분석한다.

전체 구조

요청이 파이프라인을 통과하면서 데이터 구조가 어떻게 변환되는지 살펴보자.

 Client Request (JSON)
       │
       ▼
 GenerateReqInput / EmbeddingReqInput     ← 원시 요청 (텍스트/ID)
       │  normalize_batch_and_arguments()
       │  TokenizerManager._tokenize_one_request()
       ▼
 TokenizedGenerateReqInput /               ← 토큰화된 요청
 TokenizedEmbeddingReqInput
       │  ZMQ IPC → Scheduler
       ▼
 (Scheduler 내부 처리: Req → ScheduleBatch → forward pass)
       │
       ▼
 BatchTokenIDOutput                        ← Scheduler → Detokenizer
       │  DetokenizerManager._decode_batch_token_id_output()
       ▼
 BatchStrOutput                            ← Detokenizer → Tokenizer
       │  TokenizerManager._handle_batch_output()
       ▼
 Client Response (JSON)

각 단계마다 데이터 구조가 변환된다. 원시 텍스트 → 토큰 ID → 출력 토큰 ID → 출력 텍스트의 흐름이다.

핵심 코드 분석

기반 클래스: BaseReq와 BaseBatchReq

모든 요청/응답 데이터 구조의 기반이 되는 두 개의 추상 클래스가 있다.

@dataclass
class BaseReq(ABC):
    rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
    http_worker_ipc: Optional[str] = field(default=None, kw_only=True)

    def regenerate_rid(self):
        """Generate a new request ID and return it."""
        if isinstance(self.rid, list):
            self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
        else:
            self.rid = uuid.uuid4().hex
        return self.rid

@dataclass
class BaseBatchReq(ABC):
    rids: Optional[List[str]] = field(default=None, kw_only=True)
    http_worker_ipcs: Optional[List[str]] = field(default=None, kw_only=True)

BaseReq는 단일 요청용이고, BaseBatchReq는 배치 요청/응답용이다. 두 클래스 모두 rid(request ID)와 http_worker_ipc(응답 라우팅용 IPC 이름)를 공통으로 갖는다. riduuid4().hex로 생성되며, 파이프라인 전체에서 요청을 추적하는 유일한 식별자 역할을 한다. http_worker_ipc는 multi-worker 모드에서 DetokenizerManager가 결과를 올바른 TokenizerWorker에게 돌려보내는 데 사용된다.

GenerateReqInput: 원시 요청

클라이언트로부터 받는 생성 요청의 데이터 구조다. 이 클래스가 io_struct.py에서 가장 복잡한 구조체다.

@dataclass
class GenerateReqInput(BaseReq):
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can specify either text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input.
    image_data: Optional[MultimodalDataInputFormat] = None
    video_data: Optional[MultimodalDataInputFormat] = None
    audio_data: Optional[MultimodalDataInputFormat] = None
    # The sampling_params.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # Whether to stream output.
    stream: bool = False

거의 모든 필드가 Union[List[...], ...] 형태다. 단일 요청과 배치 요청을 하나의 클래스로 처리하기 위함이다. textstr(단일) 또는 List[str](배치)이 될 수 있고, sampling_paramsDict(단일) 또는 List[Dict](배치)이 될 수 있다.

배치 정규화: normalize_batch_and_arguments

GenerateReqInput의 가장 중요한 메서드다. 다양한 입력 형태를 일관된 형태로 정규화한다.

def normalize_batch_and_arguments(self):
    self._validate_inputs()
    self._determine_batch_size()
    self._handle_parallel_sampling()

    if self.is_single:
        self._normalize_single_inputs()
    else:
        self._normalize_batch_inputs()

    self._validate_rid_uniqueness()

정규화 과정은 5단계다. 입력 검증, 배치 크기 결정, parallel sampling 처리, 입력 정규화, RID 유일성 검증 순이다. _determine_batch_size에서 is_single 플래그가 설정되어, 이후 파이프라인 전체에서 단일/배치 분기에 사용된다.

parallel sampling(n > 1)은 특히 흥미로운 부분이다.

def _handle_parallel_sampling(self):
    if self.sampling_params is None:
        self.parallel_sample_num = 1
        return
    elif isinstance(self.sampling_params, dict):
        self.parallel_sample_num = self.sampling_params.get("n", 1)

    if self.parallel_sample_num > 1 and self.is_single:
        self.is_single = False
        if self.text is not None:
            self.text = [self.text]

n=3이면 같은 프롬프트를 3번 샘플링해야 하므로, 단일 요청을 배치로 변환하고 입력을 parallel_sample_num만큼 복제한다.

getitem: 배치 인덱싱

배치 요청에서 개별 요청을 추출하는 인덱서다. TokenizerManager가 배치의 각 요청을 개별 처리할 때 사용한다.

def __getitem__(self, i):
    cache = self.__dict__.setdefault("_sub_obj_cache", {})
    if i in cache:
        return cache[i]
    sub = GenerateReqInput(
        text=self.text[i] if self.text is not None else None,
        input_ids=self.input_ids[i] if self.input_ids is not None else None,
        image_data=self.image_data[i],
        sampling_params=self.sampling_params[i],
        rid=self.rid[i],
        return_logprob=self.return_logprob[i],
        stream=self.stream,
        # ...
    )
    cache[i] = sub
    return sub

_sub_obj_cache로 캐싱하여 같은 인덱스에 대한 반복 접근 시 동일한 객체를 반환한다. 이는 여러 코드 경로에서 obj[i]를 호출할 때 서로 다른 객체가 생성되어 발생할 수 있는 미묘한 버그를 방지한다.

TokenizedGenerateReqInput: 토큰화된 요청

TokenizerManager가 토큰화를 완료한 후 Scheduler에 전달하는 데이터 구조다.

@dataclass
class TokenizedGenerateReqInput(BaseReq):
    input_text: str
    input_ids: List[int]
    mm_inputs: object
    sampling_params: SamplingParams
    return_logprob: bool
    logprob_start_len: int
    top_logprobs_num: int
    token_ids_logprob: List[int]
    stream: bool

    return_hidden_states: bool = False
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    session_params: Optional[SessionParams] = None
    lora_id: Optional[str] = None
    custom_logit_processor: Optional[str] = None

    # For disaggregated inference
    bootstrap_host: Optional[str] = None
    bootstrap_port: Optional[int] = None

GenerateReqInput과 비교하면 중요한 차이가 있다. sampling_paramsDict에서 SamplingParams 객체로 파싱되었고, text/input_ids의 Union 타입이 사라지고 확정된 단일 타입이 되었다. mm_inputs는 멀티모달 전처리 결과를 담는다. 이 변환이 TokenizerManager의 핵심 역할이다. 클라이언트의 유연한 입력을 Scheduler가 효율적으로 처리할 수 있는 정규화된 형태로 변환한다.

BatchTokenIDOutput: Scheduler에서 Detokenizer로

Scheduler가 forward pass 후 생성한 토큰 ID를 DetokenizerManager에 전달하는 배치 출력이다.

@dataclass
class BatchTokenIDOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin):
    finished_reasons: List[BaseFinishReason]
    # For incremental decoding
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    output_ids: Optional[List[int]]
    # Detokenization configs
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]

    # Token counts
    prompt_tokens: List[int]
    reasoning_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]

    # Logprobs
    input_token_logprobs_val: List[float]
    output_token_logprobs_val: List[float]
    # ...

이 구조체는 SpeculativeDecodingMetricsMixin을 상속한다. speculative decoding 메트릭(verify count, accepted tokens, acceptance histogram)을 담는 mixin으로, 코드 중복을 방지한다.

@dataclass
class SpeculativeDecodingMetricsMixin:
    spec_verify_ct: List[int]
    spec_accepted_tokens: List[int]
    spec_acceptance_histogram: List[List[int]]

decoded_texts, decode_ids, read_offsets는 증분 디코딩을 위한 필드다. DetokenizerManager의 DecodeStatus와 직접 대응된다. skip_special_tokensspaces_between_special_tokens는 요청별로 다를 수 있는 디토큰화 설정이다.

BatchStrOutput: Detokenizer에서 Tokenizer로

DetokenizerManager가 토큰 ID를 텍스트로 변환한 후 TokenizerManager에 전달하는 최종 출력이다.

@dataclass
class BatchStrOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin):
    finished_reasons: List[dict]
    output_strs: List[str]
    output_ids: Optional[List[int]]

    prompt_tokens: List[int]
    completion_tokens: List[int]
    reasoning_tokens: List[int]
    cached_tokens: List[int]

    # Logprobs (passed through from BatchTokenIDOutput)
    input_token_logprobs_val: List[float]
    output_token_logprobs_val: List[float]
    # ...

BatchTokenIDOutput에서 decode_ids, read_offsets, skip_special_tokens 등 디토큰화 전용 필드가 사라지고, 대신 output_strs(디코딩된 문자열)가 추가되었다. logprobs, token counts, speculative decoding 메트릭 등은 그대로 pass-through된다.

BatchEmbeddingOutput: 임베딩 모델 출력

임베딩 모델의 출력은 디토큰화가 필요 없으므로 별도의 데이터 구조를 사용한다.

@dataclass
class BatchEmbeddingOutput(BaseBatchReq):
    finished_reasons: List[BaseFinishReason]
    embeddings: Union[List[List[float]], List[Dict[int, float]]]
    prompt_tokens: List[int]
    cached_tokens: List[int]
    placeholder_tokens_idx: List[Optional[List[int]]]
    placeholder_tokens_val: List[Optional[List[int]]]
    retraction_counts: List[int]

SpeculativeDecodingMetricsMixin을 상속하지 않는다. 임베딩 모델에서는 speculative decoding이 의미 없기 때문이다. embeddingsList[List[float]](dense) 또는 List[Dict[int, float]](sparse) 형태를 모두 지원한다.

왜 이 설계인가

Union 타입의 유연한 입력: GenerateReqInput의 거의 모든 필드가 단일 값과 리스트를 모두 허용한다. 이는 OpenAI 호환 API의 다양한 입력 형태를 하나의 클래스로 수용하기 위함이다. normalize_batch_and_arguments에서 이 유연성을 정규화하여, 이후 파이프라인은 일관된 형태만 처리한다.

dataclass + pickle: ZMQ IPC에서 직렬화 방식으로 pickle을 사용한다. protobuf나 flatbuffers 대비 Python 객체를 그대로 전달할 수 있어 개발 생산성이 높고, 같은 머신 내 IPC에서는 성능 오버헤드가 미미하다. @dataclass__init__, __repr__, __eq__를 자동 생성하여 보일러플레이트를 줄인다.

Mixin 패턴: SpeculativeDecodingMetricsMixinBatchTokenIDOutputBatchStrOutput에서 공유되는 speculative decoding 필드를 하나로 묶는다. 새로운 메트릭이 추가될 때 한 곳만 수정하면 된다.

캐싱된 인덱서: __getitem___sub_obj_cache는 같은 인덱스에 대해 항상 동일한 객체를 반환하여, 여러 코드 경로에서 발생할 수 있는 상태 불일치를 방지한다. 미묘하지만 중요한 설계 결정이다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글