[SGLang] Disaggregated Prefill 서버: 프리필 전용 서버 구현
들어가며
Disaggregated Serving에서 Prefill 서버는 프롬프트를 처리하여 KV 캐시를 생성하고, 이를 Decode 서버로 전송하는 역할을 담당한다. SGLang의 Prefill 서버는 3단계 큐 시스템(Bootstrap → Waiting → Inflight)으로 요청의 라이프사이클을 관리한다.
구조도
요청 도착
│
▼
┌─────────────────────────────────────────────┐
│ 1. Bootstrap Queue │
│ - KVSender 초기화 │
│ - 핸드셰이크 대기 │
│ - 완료 시 Waiting Queue로 이동 │
└──────────────────┬──────────────────────────┘
│ pop_bootstrapped()
▼
┌─────────────────────────────────────────────┐
│ 2. Waiting Queue │
│ - PrefillAdder가 배치 구성 │
│ - GPU forward 실행 │
│ - 완료 시 Inflight Queue로 이동 │
└──────────────────┬──────────────────────────┘
│ send_kv_chunk()
▼
┌─────────────────────────────────────────────┐
│ 3. Inflight Queue │
│ - KV 전송 상태 폴링 (non-blocking) │
│ - 전송 완료 시 KV 캐시 해제 + 응답 │
└─────────────────────────────────────────────┘
핵심 코드 분석
PrefillBootstrapQueue: 부트스트랩 단계
python/sglang/srt/disaggregation/prefill.py의 PrefillBootstrapQueue는 요청이 도착하면 KVSender를 초기화하고 핸드셰이크를 관리한다.
class PrefillBootstrapQueue:
def __init__(self, token_to_kv_pool, req_to_metadata_buffer_idx_allocator,
metadata_buffers, tp_rank, tp_size, ...):
self.queue: List[Req] = []
self.kv_manager = self._init_kv_manager()
요청 추가 시 KVSender를 생성하여 Decode 서버와 연결을 수립한다.
def add(self, req: Req, num_kv_heads: int) -> None:
if self._check_if_req_exceed_kv_capacity(req):
return
backend = (TransferBackend.FAKE
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST
else self.transfer_backend)
kv_sender_class = get_kv_class(backend, KVClassType.SENDER)
req.disagg_kv_sender = kv_sender_class(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
dest_tp_ranks=[self.tp_rank],
pp_rank=self.pp_rank,
)
self.queue.append(req)
부트스트랩 완료 감지
pop_bootstrapped 메서드는 핸드셰이크가 완료된 요청을 큐에서 꺼낸다.
def pop_bootstrapped(self, return_failed_reqs=False, ...):
polls = poll_and_all_reduce_attn_cp_tp_group(
[req.disagg_kv_sender for req in self.queue],
self.scheduler.attn_cp_cpu_group,
self.scheduler.attn_tp_cpu_group,
)
for i, (req, poll) in enumerate(zip(self.queue, polls)):
if poll == KVPoll.WaitingForInput:
req.metadata_buffer_index = self.req_to_metadata_buffer_idx_allocator.alloc()
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
bootstrapped_reqs.append(req)
TP 전체 랭크에 대해 all-reduce로 상태를 동기화한 뒤, WaitingForInput 상태가 된 요청에 메타데이터 버퍼를 할당한다.
이벤트 루프: 스케줄러 통합
Prefill 서버의 메인 이벤트 루프는 요청 수신, 배치 구성, forward 실행, KV 전송을 반복한다.
def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
while True:
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
self.waiting_queue.extend(
self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
)
batch = self.get_next_disagg_prefill_batch_to_run()
if batch:
result = self.run_batch(batch)
self.process_batch_result(batch, result)
self.process_disagg_prefill_inflight_queue()
KV 캐시 청크 전송
Prefill이 완료되면 KV 캐시를 페이지 단위로 Decode 서버에 전송한다. Chunked Prefill 시에는 중간 청크도 전송한다.
def send_kv_chunk(self: Scheduler, req: Req, last_chunk=False, end_idx=None):
page_size = self.token_to_kv_pool_allocator.page_size
start_idx = req.start_send_idx
end_idx = end_idx if end_idx is not None else min(len(req.fill_ids), len(req.origin_input_ids))
if not last_chunk:
end_idx = end_idx - end_idx % page_size # 페이지 경계 정렬
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, start_idx:end_idx
].cpu().numpy()
page_indices = kv_to_page_indices(kv_indices, page_size)
req.disagg_kv_sender.send(page_indices, state_indices)
마지막 청크가 아닌 경우, 부분 페이지를 다음 전송으로 미루기 위해 end_idx를 페이지 크기로 정렬한다.
배치 결과 처리
Prefill forward가 완료되면 첫 번째 출력 토큰과 메타데이터를 설정하고, KV 전송을 시작한다.
def process_batch_result_disagg_prefill(self, batch, result):
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if req.is_chunked <= 0:
req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req)
self.disagg_prefill_inflight_queue.append(req)
self.send_kv_chunk(req, last_chunk=True)
Inflight 큐 처리: 전송 완료 확인
전송 중인 요청들의 상태를 비동기로 폴링하여 완료된 요청의 KV 캐시를 해제한다.
def process_disagg_prefill_inflight_queue(self):
polls = poll_and_all_reduce_attn_cp_tp_group(
[req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
self.attn_cp_cpu_group, self.attn_tp_cpu_group,
)
for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
if poll == KVPoll.Success:
release_kv_cache(req, self.tree_cache)
req.finished_reason = FINISH_LENGTH(length=0)
done_reqs.append(req)
전송이 성공하면 KV 캐시를 언락하고, 전송 지연 시간과 속도 메트릭을 기록한다.
설계 근거
max_new_tokens = 1 제약
Prefill 서버에서는 max_new_tokens를 1로 강제한다.
def _process_req(self, req: Req) -> None:
req.sampling_params.max_new_tokens = 1
이는 PrefillAdder의 메모리 추정이 정확하도록 보장한다. Prefill은 첫 번째 토큰만 생성하고, 나머지는 Decode 서버에서 처리한다.
Overlap 모드
event_loop_overlap_disagg_prefill은 현재 배치의 forward와 이전 배치의 결과 처리를 파이프라인으로 겹치는 최적화를 제공한다. 이를 통해 GPU idle 시간을 줄인다.
관련 포스트
- Prefill-Decode Disaggregation 개요
- Disaggregated Decode 서버: 디코드 전용 서버 구현
- Staging Buffer: KV 캐시 전송 버퍼 관리
참고
관련 포스트
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
- [sglang] sglang diffusion 모델 성능 향상: Cache-DiT와 torch.compile의 최적화된 적용 순서
- [sglang] NixlKVManager 성능 향상: 비동기 및 멀티스레드 KV 전송 도입
SGLang 의 다른글
- 이전글 [SGLang] Prefill-Decode Disaggregation 개요: PD 분리 아키텍처
- 현재글 : [SGLang] Disaggregated Prefill 서버: 프리필 전용 서버 구현
- 다음글 [SGLang] Disaggregated Decode 서버: 디코드 전용 서버 구현
댓글