[SGLang] ScheduleBatch & Req: 배치 데이터 구조의 설계와 생명주기
들어가며
LLM 서빙에서 하나의 요청은 입력 토큰의 prefill, 출력 토큰의 decode, 그리고 결과 반환까지 복잡한 생명주기를 거친다. SGLang은 이 과정을 Req와 ScheduleBatch라는 두 핵심 데이터 구조로 관리한다.
이 글에서는 python/sglang/srt/managers/schedule_batch.py를 중심으로, 요청 하나가 시스템에 들어와서 결과를 반환하기까지의 데이터 흐름을 분석한다.
구조도: 데이터 구조 변환 흐름
소스 코드 주석에 명시된 3단계 변환 파이프라인이 설계의 핵심이다.
ScheduleBatch ──> ModelWorkerBatch ──> ForwardBatch
(CPU 스케줄러) (CPU->GPU 변환) (GPU 텐서)
┌─────────────────────────────────────────────────────┐
│ ScheduleBatch │
│ - reqs: List[Req] # 고수준 스케줄링 데이터 │
│ - req_to_token_pool # 메모리 풀 참조 │
│ - tree_cache # RadixCache 참조 │
│ - forward_mode # EXTEND or DECODE │
│ - seq_lens, input_ids # 배치 텐서 (CPU) │
├─────────────────────────────────────────────────────┤
│ ModelWorkerBatch │
│ - ScheduleBatch의 부분집합 │
│ - GPU forward에 필요한 데이터만 포함 │
├─────────────────────────────────────────────────────┤
│ ForwardBatch │
│ - GPU 텐서 위주의 저수준 데이터 │
│ - 모델 forward 호출에 직접 사용 │
└─────────────────────────────────────────────────────┘
Req 클래스: 요청의 전체 상태
Req는 단일 요청의 입출력 상태, 메모리 정보, 샘플링 파라미터를 모두 담는 클래스다.
class Req(ReqDllmMixin):
"""The input and output status of a request."""
def __init__(
self,
rid: str,
origin_input_text: str,
origin_input_ids: List[int],
sampling_params: SamplingParams,
return_logprob: bool = False,
stream: bool = False,
lora_id: Optional[str] = None,
...
):
입출력 상태 관리
요청의 토큰 상태는 세 가지 리스트로 관리된다.
# 원본 입력 토큰
self.origin_input_ids = origin_input_ids
# 디코딩 단계에서 생성된 출력 토큰
self.output_ids = []
# fill_ids = origin_input_ids + output_ids (chunked 시 업데이트)
self.fill_ids = []
fill_ids는 현재 forward에 사용할 전체 토큰 시퀀스를 나타낸다. Chunked Prefill에서는 이 값이 청크 단위로 잘려서 업데이트된다.
KV 캐시 메모리 관리
요청 수준의 KV 캐시 추적은 세 가지 길이 변수로 이루어진다.
# KV 캐시에 커밋된 토큰 수
self.kv_committed_len = 0
# KV 캐시에 할당된 토큰 수
self.kv_allocated_len = 0
# 커밋된 KV가 해제되었는지
self.kv_committed_freed = False
kv_committed_len과 kv_allocated_len의 차이는 중요하다. 할당은 되었지만 아직 forward가 완료되지 않은 슬롯이 존재할 수 있기 때문이다.
Prefix 캐시 정보
RadixCache와의 연결은 다음 필드들이 담당한다.
# KV 캐시에서 재사용하는 공유 prefix의 인덱스
self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
# Prefill할 토큰 수
self.extend_input_len = 0
# RadixCache 트리의 마지막 노드
self.last_node: Any = None
prefix_indices는 이미 캐시에 존재하는 토큰의 KV 캐시 위치를 가리킨다. extend_input_len은 실제로 GPU에서 연산해야 할 신규 토큰 수다. 캐시 히트율이 높을수록 extend_input_len이 줄어들어 prefill이 빨라진다.
종료 조건
요청의 종료는 타입 시스템으로 구분된다.
class FINISH_MATCHED_TOKEN(BaseFinishReason): ... # stop token 매칭
class FINISH_MATCHED_STR(BaseFinishReason): ... # stop string 매칭
class FINISH_LENGTH(BaseFinishReason): ... # max_new_tokens 도달
class FINISH_ABORT(BaseFinishReason): ... # 에러/취소
ScheduleBatch 클래스: 배치의 모든 것
ScheduleBatch는 여러 Req를 묶어 하나의 GPU forward 단위로 관리한다.
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
"""Store all information of a batch on the scheduler."""
reqs: List[Req]
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
tree_cache: BasePrefixCache = None
forward_mode: ForwardMode = None # EXTEND or DECODE
batch_is_full: bool = False
배치 생성: init_new
새 prefill 배치는 init_new 팩토리 메서드로 생성된다.
@classmethod
def init_new(cls, reqs, req_to_token_pool,
token_to_kv_pool_allocator, tree_cache,
model_config, enable_overlap, spec_algorithm,
chunked_req=None):
return_logprob = any(req.return_logprob for req in reqs)
return cls(
reqs=reqs,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
tree_cache=tree_cache,
return_logprob=return_logprob,
has_stream=any(req.stream for req in reqs),
...
)
배치 내 하나의 요청이라도 return_logprob이나 stream을 요구하면, 배치 전체가 해당 모드로 동작한다.
Prefill과 Decode 전환
배치의 forward mode에 따라 텐서 준비 방식이 달라진다.
def prepare_for_extend(self):
"""Prefill 배치 준비"""
self.forward_mode = ForwardMode.EXTEND
input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs]
extend_num_tokens = sum(len(ids) for ids in input_ids)
seq_lens = [len(r.fill_ids) for r in reqs]
prefix_lens = [len(r.prefix_indices) for r in reqs]
def prepare_for_decode(self):
"""Decode 배치 준비"""
self.forward_mode = ForwardMode.DECODE
self.input_embeds = None # Prefill 텐서 정리
prepare_for_extend에서 fill_ids[len(prefix_indices):]로 prefix를 제외한 신규 토큰만 추출하는 부분이 핵심이다. 이것이 RadixCache를 통한 KV 캐시 재사용의 핵심 메커니즘이다.
배치 병합과 필터링
Continuous Batching의 핵심 연산인 merge_batch와 filter_batch가 있다.
def merge_batch(self, other: "ScheduleBatch"):
"""Prefill이 끝난 요청을 running batch에 병합"""
self.sampling_info.merge_batch(other.sampling_info)
self.req_pool_indices = torch.cat(
[self.req_pool_indices, other.req_pool_indices]
)
self.reqs.extend(other.reqs)
def filter_batch(self, chunked_req_to_exclude=None):
"""완료된 요청을 배치에서 제거"""
keep_indices = [
i for i in range(len(self.reqs))
if not self.reqs[i].finished()
and self.reqs[i] not in chunked_req_to_exclude
]
요청의 생명주기
API 요청 수신
│
v
Req 객체 생성 (rid, input_ids, sampling_params)
│
v
waiting_queue에 추가
│
v
SchedulePolicy로 우선순위 결정
│
v
PrefillAdder.add_one_req() ──> 메모리 예산 확인
│
v
ScheduleBatch.init_new() ──> Prefill 배치 구성
│
v
prepare_for_extend() ──> GPU forward 실행
│
v
running_batch에 merge ──> Decode 반복
│
v
finished_reason 설정 ──> filter_batch()로 제거
│
v
결과 반환
왜 이 설계인가
1. 3단계 변환 파이프라인: ScheduleBatch -> ModelWorkerBatch -> ForwardBatch의 변환은 관심사 분리를 극대화한다. 스케줄러는 GPU 텐서를 알 필요가 없고, 모델 러너는 스케줄링 정책을 알 필요가 없다.
2. 요청 수준 KV 캐시 추적: Req에 kv_committed_len과 kv_allocated_len을 분리하여, Chunked Prefill이나 Retraction 시에도 메모리 누수 없이 정확한 관리가 가능하다.
3. batch_is_full 플래그: 매 스텝마다 메모리 예산을 재계산하는 대신, 이전 스텝의 결과를 캐시하여 불필요한 연산을 줄인다. 새 요청이 완료되거나 추가될 때만 이 플래그가 리셋된다.
관련 포스트
- [SGLang] Zero-Overhead CPU Scheduler: 배치 스케줄링의 핵심 설계
- [SGLang] 스케줄링 정책: FCFS, SJF, Age-Penalty 비교 분석
- [SGLang] Continuous Batching & Chunked Prefill: 동적 배칭의 핵심
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Zero-Overhead CPU Scheduler: 배치 스케줄링의 핵심 설계
- 현재글 : [SGLang] ScheduleBatch & Req: 배치 데이터 구조의 설계와 생명주기
- 다음글 [SGLang] 스케줄링 정책: FCFS, LPM, LOF, DFS-Weight 비교 분석
댓글