본문으로 건너뛰기

[SGLang] ScheduleBatch & Req: 배치 데이터 구조의 설계와 생명주기

들어가며

LLM 서빙에서 하나의 요청은 입력 토큰의 prefill, 출력 토큰의 decode, 그리고 결과 반환까지 복잡한 생명주기를 거친다. SGLang은 이 과정을 ReqScheduleBatch라는 두 핵심 데이터 구조로 관리한다.

이 글에서는 python/sglang/srt/managers/schedule_batch.py를 중심으로, 요청 하나가 시스템에 들어와서 결과를 반환하기까지의 데이터 흐름을 분석한다.

구조도: 데이터 구조 변환 흐름

소스 코드 주석에 명시된 3단계 변환 파이프라인이 설계의 핵심이다.

ScheduleBatch ──> ModelWorkerBatch ──> ForwardBatch
  (CPU 스케줄러)     (CPU->GPU 변환)      (GPU 텐서)

┌─────────────────────────────────────────────────────┐
│ ScheduleBatch                                       │
│  - reqs: List[Req]        # 고수준 스케줄링 데이터   │
│  - req_to_token_pool      # 메모리 풀 참조          │
│  - tree_cache             # RadixCache 참조         │
│  - forward_mode           # EXTEND or DECODE        │
│  - seq_lens, input_ids    # 배치 텐서 (CPU)         │
├─────────────────────────────────────────────────────┤
│ ModelWorkerBatch                                    │
│  - ScheduleBatch의 부분집합                         │
│  - GPU forward에 필요한 데이터만 포함                │
├─────────────────────────────────────────────────────┤
│ ForwardBatch                                        │
│  - GPU 텐서 위주의 저수준 데이터                     │
│  - 모델 forward 호출에 직접 사용                     │
└─────────────────────────────────────────────────────┘

Req 클래스: 요청의 전체 상태

Req는 단일 요청의 입출력 상태, 메모리 정보, 샘플링 파라미터를 모두 담는 클래스다.

class Req(ReqDllmMixin):
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: List[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        stream: bool = False,
        lora_id: Optional[str] = None,
        ...
    ):

입출력 상태 관리

요청의 토큰 상태는 세 가지 리스트로 관리된다.

# 원본 입력 토큰
self.origin_input_ids = origin_input_ids
# 디코딩 단계에서 생성된 출력 토큰
self.output_ids = []
# fill_ids = origin_input_ids + output_ids (chunked 시 업데이트)
self.fill_ids = []

fill_ids는 현재 forward에 사용할 전체 토큰 시퀀스를 나타낸다. Chunked Prefill에서는 이 값이 청크 단위로 잘려서 업데이트된다.

KV 캐시 메모리 관리

요청 수준의 KV 캐시 추적은 세 가지 길이 변수로 이루어진다.

# KV 캐시에 커밋된 토큰 수
self.kv_committed_len = 0
# KV 캐시에 할당된 토큰 수
self.kv_allocated_len = 0
# 커밋된 KV가 해제되었는지
self.kv_committed_freed = False

kv_committed_lenkv_allocated_len의 차이는 중요하다. 할당은 되었지만 아직 forward가 완료되지 않은 슬롯이 존재할 수 있기 때문이다.

Prefix 캐시 정보

RadixCache와의 연결은 다음 필드들이 담당한다.

# KV 캐시에서 재사용하는 공유 prefix의 인덱스
self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
# Prefill할 토큰 수
self.extend_input_len = 0
# RadixCache 트리의 마지막 노드
self.last_node: Any = None

prefix_indices는 이미 캐시에 존재하는 토큰의 KV 캐시 위치를 가리킨다. extend_input_len은 실제로 GPU에서 연산해야 할 신규 토큰 수다. 캐시 히트율이 높을수록 extend_input_len이 줄어들어 prefill이 빨라진다.

종료 조건

요청의 종료는 타입 시스템으로 구분된다.

class FINISH_MATCHED_TOKEN(BaseFinishReason): ...  # stop token 매칭
class FINISH_MATCHED_STR(BaseFinishReason): ...    # stop string 매칭
class FINISH_LENGTH(BaseFinishReason): ...          # max_new_tokens 도달
class FINISH_ABORT(BaseFinishReason): ...           # 에러/취소

ScheduleBatch 클래스: 배치의 모든 것

ScheduleBatch는 여러 Req를 묶어 하나의 GPU forward 단위로 관리한다.

class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None
    forward_mode: ForwardMode = None  # EXTEND or DECODE
    batch_is_full: bool = False

배치 생성: init_new

새 prefill 배치는 init_new 팩토리 메서드로 생성된다.

@classmethod
def init_new(cls, reqs, req_to_token_pool,
             token_to_kv_pool_allocator, tree_cache,
             model_config, enable_overlap, spec_algorithm,
             chunked_req=None):
    return_logprob = any(req.return_logprob for req in reqs)
    return cls(
        reqs=reqs,
        req_to_token_pool=req_to_token_pool,
        token_to_kv_pool_allocator=token_to_kv_pool_allocator,
        tree_cache=tree_cache,
        return_logprob=return_logprob,
        has_stream=any(req.stream for req in reqs),
        ...
    )

배치 내 하나의 요청이라도 return_logprob이나 stream을 요구하면, 배치 전체가 해당 모드로 동작한다.

Prefill과 Decode 전환

배치의 forward mode에 따라 텐서 준비 방식이 달라진다.

def prepare_for_extend(self):
    """Prefill 배치 준비"""
    self.forward_mode = ForwardMode.EXTEND
    input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs]
    extend_num_tokens = sum(len(ids) for ids in input_ids)
    seq_lens = [len(r.fill_ids) for r in reqs]
    prefix_lens = [len(r.prefix_indices) for r in reqs]

def prepare_for_decode(self):
    """Decode 배치 준비"""
    self.forward_mode = ForwardMode.DECODE
    self.input_embeds = None  # Prefill 텐서 정리

prepare_for_extend에서 fill_ids[len(prefix_indices):]로 prefix를 제외한 신규 토큰만 추출하는 부분이 핵심이다. 이것이 RadixCache를 통한 KV 캐시 재사용의 핵심 메커니즘이다.

배치 병합과 필터링

Continuous Batching의 핵심 연산인 merge_batchfilter_batch가 있다.

def merge_batch(self, other: "ScheduleBatch"):
    """Prefill이 끝난 요청을 running batch에 병합"""
    self.sampling_info.merge_batch(other.sampling_info)
    self.req_pool_indices = torch.cat(
        [self.req_pool_indices, other.req_pool_indices]
    )
    self.reqs.extend(other.reqs)

def filter_batch(self, chunked_req_to_exclude=None):
    """완료된 요청을 배치에서 제거"""
    keep_indices = [
        i for i in range(len(self.reqs))
        if not self.reqs[i].finished()
        and self.reqs[i] not in chunked_req_to_exclude
    ]

요청의 생명주기

  API 요청 수신
       │
       v
  Req 객체 생성 (rid, input_ids, sampling_params)
       │
       v
  waiting_queue에 추가
       │
       v
  SchedulePolicy로 우선순위 결정
       │
       v
  PrefillAdder.add_one_req() ──> 메모리 예산 확인
       │
       v
  ScheduleBatch.init_new() ──> Prefill 배치 구성
       │
       v
  prepare_for_extend() ──> GPU forward 실행
       │
       v
  running_batch에 merge ──> Decode 반복
       │
       v
  finished_reason 설정 ──> filter_batch()로 제거
       │
       v
  결과 반환

왜 이 설계인가

1. 3단계 변환 파이프라인: ScheduleBatch -> ModelWorkerBatch -> ForwardBatch의 변환은 관심사 분리를 극대화한다. 스케줄러는 GPU 텐서를 알 필요가 없고, 모델 러너는 스케줄링 정책을 알 필요가 없다.

2. 요청 수준 KV 캐시 추적: Reqkv_committed_lenkv_allocated_len을 분리하여, Chunked Prefill이나 Retraction 시에도 메모리 누수 없이 정확한 관리가 가능하다.

3. batch_is_full 플래그: 매 스텝마다 메모리 예산을 재계산하는 대신, 이전 스텝의 결과를 캐시하여 불필요한 연산을 줄인다. 새 요청이 완료되거나 추가될 때만 이 플래그가 리셋된다.

관련 포스트

  • [SGLang] Zero-Overhead CPU Scheduler: 배치 스케줄링의 핵심 설계
  • [SGLang] 스케줄링 정책: FCFS, SJF, Age-Penalty 비교 분석
  • [SGLang] Continuous Batching & Chunked Prefill: 동적 배칭의 핵심

참고

댓글

관련 포스트

SGLang 의 다른글