본문으로 건너뛰기

[SGLang] Disaggregated Decode 서버: 디코드 전용 서버 구현

들어가며

Disaggregated Serving에서 Decode 서버는 Prefill 서버로부터 KV 캐시를 수신한 뒤 토큰을 순차적으로 생성하는 역할을 담당한다. Decode 서버는 4단계 큐 시스템(PreallocQueue → TransferQueue → WaitingQueue → RunningBatch)으로 요청을 관리하며, GPU 메모리 사전 할당과 비동기 KV 수신을 통해 처리량을 극대화한다.

구조도

  요청 도착 (Decode mode)
     │
     ▼
┌──────────────────────────────────────────┐
│ 1. PreallocQueue                         │
│    - KVReceiver 초기화                    │
│    - 핸드셰이크 + GPU 메모리 사전 할당      │
│    - Prefill 서버로 수신 준비 알림          │
└───────────────────┬──────────────────────┘
                    │ pop_preallocated()
                    ▼
┌──────────────────────────────────────────┐
│ 2. TransferQueue                         │
│    - KV 캐시 수신 대기 (non-blocking poll) │
│    - 완료 시 메타데이터 복원                │
└───────────────────┬──────────────────────┘
                    │
                    ▼
┌──────────────────────────────────────────┐
│ 3. WaitingQueue                          │
│    - PrebuiltExtendBatch 구성              │
│    - Prefill 생략, 메타데이터만 설정        │
└───────────────────┬──────────────────────┘
                    │
                    ▼
┌──────────────────────────────────────────┐
│ 4. RunningBatch                          │
│    - Decode forward 실행                  │
│    - 토큰 생성 루프                        │
└──────────────────────────────────────────┘

핵심 코드 분석

DecodeReqToTokenPool: 사전 할당 메모리 풀

Decode 서버는 일반 ReqToTokenPool과 달리 사전 할당 영역을 추가로 확보하여 KV 수신 중인 요청과 실행 중인 요청의 메모리를 분리한다.

class DecodeReqToTokenPool:
    def __init__(self, size, max_context_len, device, enable_memory_saver, pre_alloc_size):
        self.req_to_token = torch.zeros(
            (size + pre_alloc_size, max_context_len),
            dtype=torch.int32, device=device,
        )
        self.free_slots = list(range(size + pre_alloc_size))

pre_alloc_size만큼 추가 슬롯을 확보하여, #running <= size#pre_allocated + #transfer <= pre_alloc_size를 분리 관리한다.

DecodeRequest: 요청 래퍼

각 요청은 KVReceiver와 함께 래핑되어 전송 상태를 추적한다.

@dataclass
class DecodeRequest:
    req: Req
    kv_receiver: CommonKVReceiver
    waiting_for_input: bool = False
    metadata_buffer_index: int = -1

    @property
    def seqlen(self) -> int:
        return self.req.seqlen

DecodePreallocQueue: 사전 할당 큐

요청이 도착하면 KVReceiver를 생성하고, 핸드셰이크 완료 후 GPU 메모리를 사전 할당한다.

class DecodePreallocQueue:
    def add(self, req: Req, is_retracted=False):
        if self._check_if_req_exceed_kv_capacity(req):
            return
        decode_req = self._create_receiver_and_enqueue(req)
        prefill_dp_rank = self._resolve_prefill_dp_rank(req)
        if prefill_dp_rank is not None:
            decode_req.kv_receiver.init(prefill_dp_rank)
            return
        self.pending_reqs.append(decode_req)

Prefill 서버의 DP rank를 즉시 확인할 수 없는 경우, pending_reqs에 넣고 나중에 배치로 해결한다.

KVReceiver 생성과 핸드셰이크

Decode 서버는 Prefill 서버의 정보를 받아 수신기를 초기화한다.

def _create_receiver_and_enqueue(self, req):
    backend = (TransferBackend.FAKE
               if _is_fake_transfer(req, self.scheduler.server_args)
               else self.transfer_backend)
    kv_receiver_class = get_kv_class(backend, KVClassType.RECEIVER)
    kv_receiver = kv_receiver_class(
        mgr=self.kv_manager,
        bootstrap_addr=_bootstrap_addr(req),
        bootstrap_room=req.bootstrap_room,
    )
    decode_req = DecodeRequest(req=req, kv_receiver=kv_receiver)
    self.queue.append(decode_req)
    return decode_req

핸드셰이크 상태 업데이트

폴링으로 핸드셰이크 상태를 확인하고, 완료된 요청을 표시한다.

def _update_handshake_waiters(self, rids_to_check=None):
    polls = poll_and_all_reduce(
        [decode_req.kv_receiver for decode_req in self.queue],
        self.gloo_group,
    )
    for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
        if poll == KVPoll.WaitingForInput:
            decode_req.waiting_for_input = True
            decode_req.req.time_stats.set_bootstrap_done_time()
        elif poll == KVPoll.Failed:
            prepare_abort(decode_req.req, error_message, ...)

Prefill DP Rank 해석

다중 DP 환경에서 요청이 어떤 Prefill DP rank에서 왔는지 확인해야 한다.

def _resolve_prefill_dp_rank(self, req):
    if req.disagg_prefill_dp_rank is not None:
        return req.disagg_prefill_dp_rank
    prefill_info = self.kv_manager.prefill_info_table.get(_bootstrap_addr(req))
    if prefill_info is None:
        return None
    if prefill_info.dp_size == 1:
        return 0
    if prefill_info.follow_bootstrap_room:
        return req.bootstrap_room % prefill_info.dp_size
    return None

메모리 할당 전략

사전 할당 시 현재 할당 가능한 토큰 수를 계산하고, 요청의 입력 길이와 예약된 Decode 토큰 수를 고려한다.

def pop_preallocated(self, rids_to_check=None):
    retractable_tokens = sum(
        len(r.origin_input_ids) + len(r.output_ids)
        for r in self.scheduler.running_batch.reqs
    )
    allocatable_tokens = self._allocatable_tokens(
        retractable_tokens=retractable_tokens, count_retracted=True
    )
    for i, decode_req in enumerate(self.queue):
        origin_input_len = len(decode_req.req.origin_input_ids)
        required_tokens = origin_input_len + self.num_reserved_decode_tokens
        if required_tokens > allocatable_tokens:
            break

Retraction 메커니즘

GPU 메모리가 부족하면 실행 중인 요청의 KV 캐시를 CPU로 퇴피(retract)시키고, 나중에 다시 로드한다.

def resume_retracted_reqs(self, rids_to_check=None):
    for i, req in enumerate(self.retracted_queue):
        required_tokens = (len(req.origin_input_ids) + len(req.output_ids)
                          + self.num_reserved_decode_tokens)
        if required_tokens > allocatable_tokens:
            break
        req.is_retracted = False
        self._pre_alloc(req)
        req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)

설계 근거

사전 할당의 필요성

Decode 서버가 Prefill 서버보다 먼저 GPU 메모리를 할당해야 하는 이유는, RDMA 전송 시 목적지 메모리 주소가 필요하기 때문이다. 메모리 할당과 핸드셰이크를 병렬화하여 Prefill 서버의 대기 시간을 최소화한다.

Staging Buffer 연동

heterogeneous TP(서로 다른 TP 크기) 환경에서는 Staging Buffer를 통해 RDMA 요청 수를 줄인다. Decode 서버의 PreallocQueue 초기화 시 staging handler도 함께 설정된다.

if self.enable_staging:
    self.transfer_queue._init_staging_handler(self.kv_manager)

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글