[SGLang] Continuous Batching & Chunked Prefill: 동적 배칭의 핵심
들어가며
전통적인 Static Batching에서는 배치 내 모든 요청이 끝날 때까지 새 요청을 추가할 수 없다. 가장 긴 요청 하나가 전체 배치를 지연시키는 문제가 발생한다. Continuous Batching은 개별 요청이 끝나는 즉시 그 자리에 새 요청을 채워 GPU 활용률을 극대화한다. Chunked Prefill은 여기서 한 단계 더 나아가, 긴 프롬프트를 청크 단위로 나누어 decode 배치와 혼합 실행한다.
이 글에서는 python/sglang/srt/managers/scheduler.py의 get_new_batch_prefill 메서드를 중심으로, SGLang의 Continuous Batching과 Chunked Prefill 구현을 분석한다.
Static Batching vs Continuous Batching
[Static Batching - 빈 슬롯 낭비]
Req A: |=====prefill=====|===decode===|===decode===|===decode===|
Req B: |==prefill==|=dec=|=dec=|=dec=|xxxx idle xxxx|
Req C: |=prefill=|=d=|xxxx idle xxxxxxxxxxxxxxxxxxxxx|
Req D: | (대기 중...) | 시작 가능 |
t0 t1 t2 t3 t4
* Req C가 t1에 끝나도, Req A가 끝나는 t3까지 빈 슬롯 유지
* Req D는 전체 배치가 끝나는 t3까지 대기
[Continuous Batching - 즉시 채움]
Req A: |=====prefill=====|===decode===|===decode===|===decode===|
Req B: |==prefill==|=dec=|=dec=|=dec=|
Req C: |=prefill=|=d=|
Req D: |=pf=|=d=|=d=|=d=|=d=|=d=| <- C 끝나자마자 시작
Req E: |=pf=|=d=|=d=| <- B 끝나자마자 시작
t0 t1 t2 t3
* 빈 슬롯 즉시 재활용 -> GPU 활용률 극대화
구조도: Prefill 배치 결정 흐름
get_new_batch_prefill()
│
v
batch_is_full? ──Yes──> return None
│ No
v
SchedulePolicy.calc_priority()
(대기 큐 정렬)
│
v
PrefillAdder 생성
(메모리 예산 계산)
│
v
┌────────────────────────┐
│ chunked_req 있으면 │
│ add_chunked_req() 먼저 │
└───────┬────────────────┘
v
┌─────────────────────────────┐
│ for req in waiting_queue: │
│ req.init_next_round_input │
│ adder.add_one_req(req) │
│ if NO_TOKEN: break │
└───────┬─────────────────────┘
v
ScheduleBatch.init_new(can_run_list)
get_new_batch_prefill: 핵심 진입점
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
prefill_delayer_single_pass = None
if self.prefill_delayer:
_, token_usage, _, _ = self._get_token_info()
prefill_delayer_single_pass = PrefillDelayerSinglePassExecutor(
self.prefill_delayer, token_usage=token_usage
)
ret = self._get_new_batch_prefill_raw(
prefill_delayer_single_pass=prefill_delayer_single_pass
)
if self.prefill_delayer:
prefill_delayer_single_pass.finalize(
actual_prefill=ret is not None
)
return ret
PrefillDelayer는 토큰 사용량이 높을 때 prefill을 지연시켜 decode throughput을 보호하는 메커니즘이다.
_get_new_batch_prefill_raw: 배치 구성 로직
조기 종료 조건
def _get_new_batch_prefill_raw(self, prefill_delayer_single_pass):
# 배치가 이미 꽉 찼거나 대기 큐가 비었으면 건너뜀
if (
self.running_batch.batch_is_full
or len(self.waiting_queue) == 0
) and self.chunked_req is None:
return None
running_bs = len(self.running_batch.reqs)
if self.get_num_allocatable_reqs(running_bs) <= 0 \
and self.chunked_req is not None:
self.running_batch.batch_is_full = True
return None
batch_is_full 플래그는 이전 스텝에서 토큰이 부족하다고 판단된 경우 true로 설정되어, 불필요한 메모리 예산 계산을 방지한다.
PrefillAdder 생성
adder = PrefillAdder(
self.page_size,
self.tree_cache,
self.token_to_kv_pool_allocator,
self.running_batch,
self.new_token_ratio,
self.max_prefill_tokens, # rem_input_tokens
chunked_prefill_size, # rem_chunk_tokens
running_bs if self.is_mixed_chunk else 0,
...
)
PrefillAdder는 두 가지 토큰 예산을 관리한다.
rem_input_tokens: 이번 스텝에서 prefill에 사용할 수 있는 총 입력 토큰 수rem_chunk_tokens: Chunked Prefill에서 하나의 청크에 허용되는 최대 토큰 수
Chunked Request 우선 처리
이전 스텝에서 중단된 chunked request가 있으면 먼저 추가한다.
if self.chunked_req is not None:
self.chunked_req.init_next_round_input()
self.chunked_req = adder.add_chunked_req(self.chunked_req)
대기 큐 순회
for req in self.waiting_queue:
if self.running_batch.batch_is_full:
break
req.init_next_round_input(self.tree_cache)
res = adder.add_one_req(
req,
has_chunked_req=(self.chunked_req is not None),
truncation_align_size=self.truncation_align_size,
)
if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN:
self.running_batch.batch_is_full = True
break
AddReqResult의 세 가지 상태가 제어 흐름을 결정한다.
CONTINUE: 다음 요청도 추가 시도NO_TOKEN: KV 캐시 메모리 부족, 중단OTHER: 입력 토큰 예산 소진 등, 중단
Chunked Prefill 구현
긴 프롬프트를 한 번에 처리하면 decode 요청이 지연된다. Chunked Prefill은 이를 방지한다.
add_one_req에서의 청크 분할
def add_one_req(self, req, has_chunked_req, truncation_align_size):
total_tokens = req.extend_input_len + min(
max(req.sampling_params.max_new_tokens - len(req.output_ids), 0),
CLIP_MAX_NEW_TOKENS,
)
if total_tokens >= self.rem_total_tokens:
return AddReqResult.NO_TOKEN
# Non-chunked prefill
if self.rem_chunk_tokens is None \
or input_tokens <= self.rem_chunk_tokens:
self.can_run_list.append(req)
self._update_prefill_budget(
prefix_len, req.extend_input_len,
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
)
else:
# Chunked prefill: 청크 크기로 자름
if not has_chunked_req:
# 청크 분할 실행
self.new_chunked_req = adder.add_chunked_req(req)
add_chunked_req: 청크 분할 실행
def add_chunked_req(self, req):
_rem_tokens = min(
self.rem_chunk_tokens, int(self.rem_total_tokens)
)
if _rem_tokens <= 0:
_rem_tokens = self.rem_chunk_tokens
truncated = req.extend_input_len > _rem_tokens
req.set_extend_input_len(
min(req.extend_input_len, _rem_tokens)
)
req.fill_ids = req.fill_ids[
:len(req.prefix_indices) + req.extend_input_len
]
self.can_run_list.append(req)
# 잘린 경우 다음 스텝에서 이어서 처리
return req if truncated else None
truncated가 True이면 요청 객체를 반환하여 self.chunked_req에 저장한다. 다음 스텝에서 init_next_round_input()이 호출되면 남은 부분부터 이어서 처리한다.
Chunked Prefill 동작 예시
[10,000 토큰 프롬프트, chunked_prefill_size = 4096]
Step 1: Prefill 청크 1
fill_ids[0:4096] -> GPU forward
chunked_req 저장 (남은: 5,904 토큰)
+ Decode batch (기존 running 요청들)
Step 2: Prefill 청크 2
fill_ids[4096:8192] -> GPU forward
chunked_req 저장 (남은: 1,808 토큰)
+ Decode batch
Step 3: Prefill 청크 3 (마지막)
fill_ids[8192:10000] -> GPU forward
chunked_req = None (완료)
-> running_batch에 merge
+ Decode batch
* 매 스텝마다 decode 요청도 함께 처리 -> TPOT 유지
Mixed Chunked Prefill
is_mixed_chunk가 활성화되면 prefill과 decode 토큰이 하나의 forward에서 함께 처리된다.
adder = PrefillAdder(
...
mixed_with_decode_tokens=running_bs if self.is_mixed_chunk else 0,
)
PrefillAdder 내부에서 decode 토큰 수를 예산에서 차감한다.
def __init__(self, ...):
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= mixed_with_decode_tokens
이렇게 하면 prefill 토큰과 decode 토큰의 합이 GPU 메모리 한도를 넘지 않도록 보장한다.
메모리 예산 관리
PrefillAdder의 메모리 예산 시스템은 두 수준에서 동작한다.
@property
def rem_total_tokens(self):
"""전체 KV 캐시 풀에서 사용 가능한 토큰 수"""
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
return available_and_evictable - self.rem_total_token_offset
def budget_state(self):
if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
return AddReqResult.NO_TOKEN
if self.rem_input_tokens <= 0:
return AddReqResult.OTHER
if self.rem_chunk_tokens is not None \
and self.rem_chunk_tokens <= 0:
return AddReqResult.OTHER
return AddReqResult.CONTINUE
rem_total_tokens는 KV 캐시 풀의 가용 크기에 RadixCache에서 evict 가능한 크기를 더한 값이다. 새 요청을 추가할 때마다 rem_total_token_offset이 증가하여 예산이 줄어든다. CLIP_MAX_NEW_TOKENS (기본값 4096)으로 미래 출력 토큰 예약량을 제한하여 보수적 추정을 방지한다.
CLIP_MAX_NEW_TOKENS = int(
os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)
동적 청크 크기
SGLang은 고정 청크 크기 대신 동적 청크 크기도 지원한다.
chunked_prefill_size = self.chunked_prefill_size
if self.chunked_req is not None and self.enable_dynamic_chunking:
history_len = len(self.chunked_req.prefix_indices)
dynamic_size = self.predict_next_chunk_size(history_len)
if dynamic_size is not None:
chunked_prefill_size = dynamic_size
이전 청크의 prefix 길이를 기반으로 다음 청크 크기를 예측한다. Attention의 연산량이 시퀀스 길이에 따라 비선형적으로 증가하므로, 청크 크기를 적응적으로 조절하면 각 스텝의 latency를 균일하게 유지할 수 있다.
왜 이 설계인가
1. Prefill 우선, 단 하나의 chunked_req: SGLang은 동시에 하나의 chunked request만 허용한다. 여러 개를 허용하면 메모리 관리가 복잡해지고, 실제로는 하나씩 처리해도 충분한 throughput을 달성할 수 있기 때문이다.
2. 3단계 AddReqResult: CONTINUE, NO_TOKEN, OTHER의 구분은 실용적이다. NO_TOKEN은 KV 캐시 부족으로 추가 요청이 불가능하지만, OTHER는 입력 토큰 예산 소진으로 현재 배치는 실행 가능하다.
3. CLIP_MAX_NEW_TOKENS: max_new_tokens를 매우 크게 설정한 요청이 메모리 예산을 과도하게 잡아먹는 것을 방지한다. 실제 생성 토큰은 대부분 4096 이하이므로, 이 클리핑은 실질적으로 throughput을 높인다.
관련 포스트
- [SGLang] Zero-Overhead CPU Scheduler: 배치 스케줄링의 핵심 설계
- [SGLang] ScheduleBatch & Req: 배치 데이터 구조의 설계와 생명주기
- [SGLang] 스케줄링 정책: FCFS, LPM, LOF, DFS-Weight 비교 분석
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] 스케줄링 정책: FCFS, LPM, LOF, DFS-Weight 비교 분석
- 현재글 : [SGLang] Continuous Batching & Chunked Prefill: 동적 배칭의 핵심
- 다음글 [SGLang] Pipeline Parallelism 스케줄러: PP 믹스인 설계
댓글