본문으로 건너뛰기

[SGLang] Disaggregated Prefill 서버: 프리필 전용 서버 구현

들어가며

Disaggregated Serving에서 Prefill 서버는 프롬프트를 처리하여 KV 캐시를 생성하고, 이를 Decode 서버로 전송하는 역할을 담당한다. SGLang의 Prefill 서버는 3단계 큐 시스템(Bootstrap → Waiting → Inflight)으로 요청의 라이프사이클을 관리한다.

구조도

  요청 도착
     │
     ▼
┌─────────────────────────────────────────────┐
│ 1. Bootstrap Queue                          │
│    - KVSender 초기화                         │
│    - 핸드셰이크 대기                          │
│    - 완료 시 Waiting Queue로 이동             │
└──────────────────┬──────────────────────────┘
                   │ pop_bootstrapped()
                   ▼
┌─────────────────────────────────────────────┐
│ 2. Waiting Queue                            │
│    - PrefillAdder가 배치 구성                 │
│    - GPU forward 실행                        │
│    - 완료 시 Inflight Queue로 이동            │
└──────────────────┬──────────────────────────┘
                   │ send_kv_chunk()
                   ▼
┌─────────────────────────────────────────────┐
│ 3. Inflight Queue                           │
│    - KV 전송 상태 폴링 (non-blocking)         │
│    - 전송 완료 시 KV 캐시 해제 + 응답          │
└─────────────────────────────────────────────┘

핵심 코드 분석

PrefillBootstrapQueue: 부트스트랩 단계

python/sglang/srt/disaggregation/prefill.pyPrefillBootstrapQueue는 요청이 도착하면 KVSender를 초기화하고 핸드셰이크를 관리한다.

class PrefillBootstrapQueue:
    def __init__(self, token_to_kv_pool, req_to_metadata_buffer_idx_allocator,
                 metadata_buffers, tp_rank, tp_size, ...):
        self.queue: List[Req] = []
        self.kv_manager = self._init_kv_manager()

요청 추가 시 KVSender를 생성하여 Decode 서버와 연결을 수립한다.

def add(self, req: Req, num_kv_heads: int) -> None:
    if self._check_if_req_exceed_kv_capacity(req):
        return
    backend = (TransferBackend.FAKE
               if req.bootstrap_host == FAKE_BOOTSTRAP_HOST
               else self.transfer_backend)
    kv_sender_class = get_kv_class(backend, KVClassType.SENDER)
    req.disagg_kv_sender = kv_sender_class(
        mgr=self.kv_manager,
        bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
        bootstrap_room=req.bootstrap_room,
        dest_tp_ranks=[self.tp_rank],
        pp_rank=self.pp_rank,
    )
    self.queue.append(req)

부트스트랩 완료 감지

pop_bootstrapped 메서드는 핸드셰이크가 완료된 요청을 큐에서 꺼낸다.

def pop_bootstrapped(self, return_failed_reqs=False, ...):
    polls = poll_and_all_reduce_attn_cp_tp_group(
        [req.disagg_kv_sender for req in self.queue],
        self.scheduler.attn_cp_cpu_group,
        self.scheduler.attn_tp_cpu_group,
    )
    for i, (req, poll) in enumerate(zip(self.queue, polls)):
        if poll == KVPoll.WaitingForInput:
            req.metadata_buffer_index = self.req_to_metadata_buffer_idx_allocator.alloc()
            num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
            req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
            bootstrapped_reqs.append(req)

TP 전체 랭크에 대해 all-reduce로 상태를 동기화한 뒤, WaitingForInput 상태가 된 요청에 메타데이터 버퍼를 할당한다.

이벤트 루프: 스케줄러 통합

Prefill 서버의 메인 이벤트 루프는 요청 수신, 배치 구성, forward 실행, KV 전송을 반복한다.

def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
    while True:
        recv_reqs = self.recv_requests()
        self.process_input_requests(recv_reqs)
        self.waiting_queue.extend(
            self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
        )
        batch = self.get_next_disagg_prefill_batch_to_run()
        if batch:
            result = self.run_batch(batch)
            self.process_batch_result(batch, result)
        self.process_disagg_prefill_inflight_queue()

KV 캐시 청크 전송

Prefill이 완료되면 KV 캐시를 페이지 단위로 Decode 서버에 전송한다. Chunked Prefill 시에는 중간 청크도 전송한다.

def send_kv_chunk(self: Scheduler, req: Req, last_chunk=False, end_idx=None):
    page_size = self.token_to_kv_pool_allocator.page_size
    start_idx = req.start_send_idx
    end_idx = end_idx if end_idx is not None else min(len(req.fill_ids), len(req.origin_input_ids))

    if not last_chunk:
        end_idx = end_idx - end_idx % page_size  # 페이지 경계 정렬

    kv_indices = self.req_to_token_pool.req_to_token[
        req.req_pool_idx, start_idx:end_idx
    ].cpu().numpy()
    page_indices = kv_to_page_indices(kv_indices, page_size)
    req.disagg_kv_sender.send(page_indices, state_indices)

마지막 청크가 아닌 경우, 부분 페이지를 다음 전송으로 미루기 위해 end_idx를 페이지 크기로 정렬한다.

배치 결과 처리

Prefill forward가 완료되면 첫 번째 출력 토큰과 메타데이터를 설정하고, KV 전송을 시작한다.

def process_batch_result_disagg_prefill(self, batch, result):
    for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
        if req.is_chunked <= 0:
            req.output_ids.append(next_token_id)
            self.tree_cache.cache_unfinished_req(req)
            self.disagg_prefill_inflight_queue.append(req)
            self.send_kv_chunk(req, last_chunk=True)

Inflight 큐 처리: 전송 완료 확인

전송 중인 요청들의 상태를 비동기로 폴링하여 완료된 요청의 KV 캐시를 해제한다.

def process_disagg_prefill_inflight_queue(self):
    polls = poll_and_all_reduce_attn_cp_tp_group(
        [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
        self.attn_cp_cpu_group, self.attn_tp_cpu_group,
    )
    for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
        if poll == KVPoll.Success:
            release_kv_cache(req, self.tree_cache)
            req.finished_reason = FINISH_LENGTH(length=0)
            done_reqs.append(req)

전송이 성공하면 KV 캐시를 언락하고, 전송 지연 시간과 속도 메트릭을 기록한다.

설계 근거

max_new_tokens = 1 제약

Prefill 서버에서는 max_new_tokens를 1로 강제한다.

def _process_req(self, req: Req) -> None:
    req.sampling_params.max_new_tokens = 1

이는 PrefillAdder의 메모리 추정이 정확하도록 보장한다. Prefill은 첫 번째 토큰만 생성하고, 나머지는 Decode 서버에서 처리한다.

Overlap 모드

event_loop_overlap_disagg_prefill은 현재 배치의 forward와 이전 배치의 결과 처리를 파이프라인으로 겹치는 최적화를 제공한다. 이를 통해 GPU idle 시간을 줄인다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글