본문으로 건너뛰기

[SGLang] Continuous Batching & Chunked Prefill: 동적 배칭의 핵심

들어가며

전통적인 Static Batching에서는 배치 내 모든 요청이 끝날 때까지 새 요청을 추가할 수 없다. 가장 긴 요청 하나가 전체 배치를 지연시키는 문제가 발생한다. Continuous Batching은 개별 요청이 끝나는 즉시 그 자리에 새 요청을 채워 GPU 활용률을 극대화한다. Chunked Prefill은 여기서 한 단계 더 나아가, 긴 프롬프트를 청크 단위로 나누어 decode 배치와 혼합 실행한다.

이 글에서는 python/sglang/srt/managers/scheduler.pyget_new_batch_prefill 메서드를 중심으로, SGLang의 Continuous Batching과 Chunked Prefill 구현을 분석한다.

Static Batching vs Continuous Batching

[Static Batching - 빈 슬롯 낭비]

 Req A: |=====prefill=====|===decode===|===decode===|===decode===|
 Req B: |==prefill==|=dec=|=dec=|=dec=|xxxx idle xxxx|
 Req C: |=prefill=|=d=|xxxx idle xxxxxxxxxxxxxxxxxxxxx|
 Req D: |                 (대기 중...)                | 시작 가능 |
        t0               t1          t2             t3          t4

  * Req C가 t1에 끝나도, Req A가 끝나는 t3까지 빈 슬롯 유지
  * Req D는 전체 배치가 끝나는 t3까지 대기

[Continuous Batching - 즉시 채움]

 Req A: |=====prefill=====|===decode===|===decode===|===decode===|
 Req B: |==prefill==|=dec=|=dec=|=dec=|
 Req C: |=prefill=|=d=|
 Req D:                |=pf=|=d=|=d=|=d=|=d=|=d=|  <- C 끝나자마자 시작
 Req E:                          |=pf=|=d=|=d=|     <- B 끝나자마자 시작
        t0               t1          t2             t3

  * 빈 슬롯 즉시 재활용 -> GPU 활용률 극대화

구조도: Prefill 배치 결정 흐름

get_new_batch_prefill()
       │
       v
  batch_is_full? ──Yes──> return None
       │ No
       v
  SchedulePolicy.calc_priority()
  (대기 큐 정렬)
       │
       v
  PrefillAdder 생성
  (메모리 예산 계산)
       │
       v
  ┌────────────────────────┐
  │ chunked_req 있으면     │
  │ add_chunked_req() 먼저 │
  └───────┬────────────────┘
          v
  ┌─────────────────────────────┐
  │ for req in waiting_queue:   │
  │   req.init_next_round_input │
  │   adder.add_one_req(req)    │
  │   if NO_TOKEN: break        │
  └───────┬─────────────────────┘
          v
  ScheduleBatch.init_new(can_run_list)

get_new_batch_prefill: 핵심 진입점

def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
    prefill_delayer_single_pass = None
    if self.prefill_delayer:
        _, token_usage, _, _ = self._get_token_info()
        prefill_delayer_single_pass = PrefillDelayerSinglePassExecutor(
            self.prefill_delayer, token_usage=token_usage
        )

    ret = self._get_new_batch_prefill_raw(
        prefill_delayer_single_pass=prefill_delayer_single_pass
    )

    if self.prefill_delayer:
        prefill_delayer_single_pass.finalize(
            actual_prefill=ret is not None
        )
    return ret

PrefillDelayer는 토큰 사용량이 높을 때 prefill을 지연시켜 decode throughput을 보호하는 메커니즘이다.

_get_new_batch_prefill_raw: 배치 구성 로직

조기 종료 조건

def _get_new_batch_prefill_raw(self, prefill_delayer_single_pass):
    # 배치가 이미 꽉 찼거나 대기 큐가 비었으면 건너뜀
    if (
        self.running_batch.batch_is_full
        or len(self.waiting_queue) == 0
    ) and self.chunked_req is None:
        return None

    running_bs = len(self.running_batch.reqs)
    if self.get_num_allocatable_reqs(running_bs) <= 0 \
       and self.chunked_req is not None:
        self.running_batch.batch_is_full = True
        return None

batch_is_full 플래그는 이전 스텝에서 토큰이 부족하다고 판단된 경우 true로 설정되어, 불필요한 메모리 예산 계산을 방지한다.

PrefillAdder 생성

    adder = PrefillAdder(
        self.page_size,
        self.tree_cache,
        self.token_to_kv_pool_allocator,
        self.running_batch,
        self.new_token_ratio,
        self.max_prefill_tokens,     # rem_input_tokens
        chunked_prefill_size,        # rem_chunk_tokens
        running_bs if self.is_mixed_chunk else 0,
        ...
    )

PrefillAdder는 두 가지 토큰 예산을 관리한다.

  • rem_input_tokens: 이번 스텝에서 prefill에 사용할 수 있는 총 입력 토큰 수
  • rem_chunk_tokens: Chunked Prefill에서 하나의 청크에 허용되는 최대 토큰 수

Chunked Request 우선 처리

이전 스텝에서 중단된 chunked request가 있으면 먼저 추가한다.

    if self.chunked_req is not None:
        self.chunked_req.init_next_round_input()
        self.chunked_req = adder.add_chunked_req(self.chunked_req)

대기 큐 순회

    for req in self.waiting_queue:
        if self.running_batch.batch_is_full:
            break

        req.init_next_round_input(self.tree_cache)
        res = adder.add_one_req(
            req,
            has_chunked_req=(self.chunked_req is not None),
            truncation_align_size=self.truncation_align_size,
        )

        if res != AddReqResult.CONTINUE:
            if res == AddReqResult.NO_TOKEN:
                self.running_batch.batch_is_full = True
            break

AddReqResult의 세 가지 상태가 제어 흐름을 결정한다.

  • CONTINUE: 다음 요청도 추가 시도
  • NO_TOKEN: KV 캐시 메모리 부족, 중단
  • OTHER: 입력 토큰 예산 소진 등, 중단

Chunked Prefill 구현

긴 프롬프트를 한 번에 처리하면 decode 요청이 지연된다. Chunked Prefill은 이를 방지한다.

add_one_req에서의 청크 분할

def add_one_req(self, req, has_chunked_req, truncation_align_size):
    total_tokens = req.extend_input_len + min(
        max(req.sampling_params.max_new_tokens - len(req.output_ids), 0),
        CLIP_MAX_NEW_TOKENS,
    )

    if total_tokens >= self.rem_total_tokens:
        return AddReqResult.NO_TOKEN

    # Non-chunked prefill
    if self.rem_chunk_tokens is None \
       or input_tokens <= self.rem_chunk_tokens:
        self.can_run_list.append(req)
        self._update_prefill_budget(
            prefix_len, req.extend_input_len,
            min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
        )
    else:
        # Chunked prefill: 청크 크기로 자름
        if not has_chunked_req:
            # 청크 분할 실행
            self.new_chunked_req = adder.add_chunked_req(req)

add_chunked_req: 청크 분할 실행

def add_chunked_req(self, req):
    _rem_tokens = min(
        self.rem_chunk_tokens, int(self.rem_total_tokens)
    )
    if _rem_tokens <= 0:
        _rem_tokens = self.rem_chunk_tokens

    truncated = req.extend_input_len > _rem_tokens
    req.set_extend_input_len(
        min(req.extend_input_len, _rem_tokens)
    )
    req.fill_ids = req.fill_ids[
        :len(req.prefix_indices) + req.extend_input_len
    ]
    self.can_run_list.append(req)

    # 잘린 경우 다음 스텝에서 이어서 처리
    return req if truncated else None

truncated가 True이면 요청 객체를 반환하여 self.chunked_req에 저장한다. 다음 스텝에서 init_next_round_input()이 호출되면 남은 부분부터 이어서 처리한다.

Chunked Prefill 동작 예시

[10,000 토큰 프롬프트, chunked_prefill_size = 4096]

Step 1: Prefill 청크 1
  fill_ids[0:4096] -> GPU forward
  chunked_req 저장 (남은: 5,904 토큰)
  + Decode batch (기존 running 요청들)

Step 2: Prefill 청크 2
  fill_ids[4096:8192] -> GPU forward
  chunked_req 저장 (남은: 1,808 토큰)
  + Decode batch

Step 3: Prefill 청크 3 (마지막)
  fill_ids[8192:10000] -> GPU forward
  chunked_req = None (완료)
  -> running_batch에 merge
  + Decode batch

* 매 스텝마다 decode 요청도 함께 처리 -> TPOT 유지

Mixed Chunked Prefill

is_mixed_chunk가 활성화되면 prefill과 decode 토큰이 하나의 forward에서 함께 처리된다.

adder = PrefillAdder(
    ...
    mixed_with_decode_tokens=running_bs if self.is_mixed_chunk else 0,
)

PrefillAdder 내부에서 decode 토큰 수를 예산에서 차감한다.

def __init__(self, ...):
    self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
    if self.rem_chunk_tokens is not None:
        self.rem_chunk_tokens -= mixed_with_decode_tokens

이렇게 하면 prefill 토큰과 decode 토큰의 합이 GPU 메모리 한도를 넘지 않도록 보장한다.

메모리 예산 관리

PrefillAdder의 메모리 예산 시스템은 두 수준에서 동작한다.

@property
def rem_total_tokens(self):
    """전체 KV 캐시 풀에서 사용 가능한 토큰 수"""
    available_and_evictable = (
        self.token_to_kv_pool_allocator.available_size()
        + self.tree_cache.evictable_size()
    )
    return available_and_evictable - self.rem_total_token_offset

def budget_state(self):
    if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
        return AddReqResult.NO_TOKEN
    if self.rem_input_tokens <= 0:
        return AddReqResult.OTHER
    if self.rem_chunk_tokens is not None \
       and self.rem_chunk_tokens <= 0:
        return AddReqResult.OTHER
    return AddReqResult.CONTINUE

rem_total_tokens는 KV 캐시 풀의 가용 크기에 RadixCache에서 evict 가능한 크기를 더한 값이다. 새 요청을 추가할 때마다 rem_total_token_offset이 증가하여 예산이 줄어든다. CLIP_MAX_NEW_TOKENS (기본값 4096)으로 미래 출력 토큰 예약량을 제한하여 보수적 추정을 방지한다.

CLIP_MAX_NEW_TOKENS = int(
    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)

동적 청크 크기

SGLang은 고정 청크 크기 대신 동적 청크 크기도 지원한다.

chunked_prefill_size = self.chunked_prefill_size
if self.chunked_req is not None and self.enable_dynamic_chunking:
    history_len = len(self.chunked_req.prefix_indices)
    dynamic_size = self.predict_next_chunk_size(history_len)
    if dynamic_size is not None:
        chunked_prefill_size = dynamic_size

이전 청크의 prefix 길이를 기반으로 다음 청크 크기를 예측한다. Attention의 연산량이 시퀀스 길이에 따라 비선형적으로 증가하므로, 청크 크기를 적응적으로 조절하면 각 스텝의 latency를 균일하게 유지할 수 있다.

왜 이 설계인가

1. Prefill 우선, 단 하나의 chunked_req: SGLang은 동시에 하나의 chunked request만 허용한다. 여러 개를 허용하면 메모리 관리가 복잡해지고, 실제로는 하나씩 처리해도 충분한 throughput을 달성할 수 있기 때문이다.

2. 3단계 AddReqResult: CONTINUE, NO_TOKEN, OTHER의 구분은 실용적이다. NO_TOKEN은 KV 캐시 부족으로 추가 요청이 불가능하지만, OTHER는 입력 토큰 예산 소진으로 현재 배치는 실행 가능하다.

3. CLIP_MAX_NEW_TOKENS: max_new_tokens를 매우 크게 설정한 요청이 메모리 예산을 과도하게 잡아먹는 것을 방지한다. 실제 생성 토큰은 대부분 4096 이하이므로, 이 클리핑은 실질적으로 throughput을 높인다.

관련 포스트

  • [SGLang] Zero-Overhead CPU Scheduler: 배치 스케줄링의 핵심 설계
  • [SGLang] ScheduleBatch & Req: 배치 데이터 구조의 설계와 생명주기
  • [SGLang] 스케줄링 정책: FCFS, LPM, LOF, DFS-Weight 비교 분석

참고

댓글

관련 포스트

SGLang 의 다른글