본문으로 건너뛰기

[SGLang] Prefill-Decode Disaggregation 개요: PD 분리 아키텍처

들어가며

LLM 추론에서 Prefill(프롬프트 처리)과 Decode(토큰 생성)는 근본적으로 다른 연산 특성을 갖는다. Prefill은 compute-bound이고, Decode는 memory-bound이다. SGLang은 이 두 단계를 물리적으로 분리된 서버에서 실행하는 Prefill-Decode Disaggregation 아키텍처를 구현하여, 각 단계에 최적화된 GPU 할당과 스케줄링을 가능하게 한다.

구조도

                    ┌─────────────────┐
                    │  Load Balancer  │
                    │  (Router)       │
                    └────────┬────────┘
                             │
              ┌──────────────┼──────────────┐
              ▼              │              ▼
   ┌──────────────────┐      │    ┌──────────────────┐
   │  Prefill Server  │      │    │  Prefill Server  │
   │  (PREFILL mode)  │      │    │  (PREFILL mode)  │
   │  ┌────────────┐  │      │    │  ┌────────────┐  │
   │  │ KV Sender  │──┼──────┼────┼──│ KV Sender  │  │
   │  └────────────┘  │      │    │  └────────────┘  │
   └──────────────────┘      │    └──────────────────┘
              │              │              │
              │  KV Cache Transfer (RDMA)   │
              │              │              │
              ▼              ▼              ▼
   ┌──────────────────┐    ┌──────────────────┐
   │  Decode Server   │    │  Decode Server   │
   │  (DECODE mode)   │    │  (DECODE mode)   │
   │  ┌────────────┐  │    │  ┌────────────┐  │
   │  │KV Receiver │  │    │  │KV Receiver │  │
   │  └────────────┘  │    │  └────────────┘  │
   └──────────────────┘    └──────────────────┘

핵심 코드 분석

DisaggregationMode: 서버 역할 정의

python/sglang/srt/disaggregation/utils.py에서 서버의 역할을 enum으로 정의한다.

class DisaggregationMode(Enum):
    NULL = "null"
    PREFILL = "prefill"
    DECODE = "decode"

NULL은 disaggregation을 사용하지 않는 통합 서빙 모드이고, PREFILLDECODE는 각각 전용 서버 모드를 나타낸다.

TransferBackend: 전송 엔진 선택

KV 캐시 전송을 위한 백엔드를 선택할 수 있다. Mooncake, NIXL, MORI 등 다양한 RDMA 기반 전송 엔진을 지원한다.

class TransferBackend(Enum):
    MOONCAKE = "mooncake"
    MORI = "mori"
    NIXL = "nixl"
    ASCEND = "ascend"
    FAKE = "fake"

KVArgs: 전송 메타데이터 구조

python/sglang/srt/disaggregation/base/conn.pyKVArgs는 KV 캐시 전송에 필요한 모든 메타데이터를 정의한다.

class KVArgs:
    engine_rank: int
    kv_data_ptrs: List[int]      # KV 캐시 GPU 메모리 포인터
    kv_data_lens: List[int]      # 각 버퍼의 크기
    kv_item_lens: List[int]      # 항목당 크기
    aux_data_ptrs: List[int]     # 메타데이터 버퍼 포인터
    page_size: int               # 페이지 크기
    pp_rank: int                 # Pipeline Parallel 랭크
    system_dp_rank: int          # Data Parallel 랭크

이 구조체는 Prefill과 Decode 서버 모두에서 초기화되어, 양쪽이 서로의 메모리 레이아웃을 파악할 수 있게 한다.

KVPoll: 전송 상태 머신

전송 상태는 5단계 상태 머신으로 관리된다.

class KVPoll:
    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4

Bootstrapping에서 핸드셰이크가 완료되면 WaitingForInput으로 전환되고, 실제 전송이 시작되면 Transferring, 완료되면 Success가 된다.

분산 폴링과 동기화

TP(Tensor Parallel) 환경에서 모든 랭크가 동일한 전송 상태를 관찰하도록 all-reduce를 수행한다.

def poll_and_all_reduce(pollers, gloo_group: dist.ProcessGroup):
    polls = [int(poller.poll()) for poller in pollers]
    tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
    dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
    return tensor_to_reduce.tolist()

MIN 연산을 사용하여 가장 느린 랭크의 상태에 맞춘다. 한 랭크라도 Failed(0)이면 전체가 실패로 처리된다.

MetadataBuffers: 첫 토큰 메타데이터 전송

Prefill 서버에서 생성한 첫 번째 출력 토큰과 관련 메타데이터를 Decode 서버로 전달하기 위한 버퍼이다.

class MetadataBuffers:
    def __init__(self, size, hidden_size, hidden_states_dtype, ...):
        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
        self.output_hidden_states = torch.zeros(
            (size, hidden_size), dtype=hidden_states_dtype, device=device
        )
        self.bootstrap_room = torch.zeros(
            (size, 8), dtype=bootstrap_room_dtype, device=device
        )

RDMA 최소 전송 크기(64Bytes)를 고려하여 패딩된 텐서를 사용한다.

동적 백엔드 로딩

get_kv_class 함수는 전송 백엔드에 따라 적절한 Manager, Sender, Receiver 클래스를 동적으로 반환한다.

def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
    if transfer_backend == TransferBackend.MOONCAKE:
        class_mapping = {
            KVClassType.MANAGER: MooncakeKVManager,
            KVClassType.SENDER: MooncakeKVSender,
            KVClassType.RECEIVER: MooncakeKVReceiver,
            KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
        }
        return class_mapping.get(class_type)

설계 근거

왜 PD를 분리하는가?

  1. GPU 활용률 극대화: Prefill은 높은 compute 활용률, Decode는 높은 memory bandwidth 활용률을 요구한다. 분리하면 각 GPU를 최적 워크로드에 할당할 수 있다.
  2. 간섭 제거: 통합 서빙에서는 긴 Prefill이 Decode 지연을 유발한다. 분리하면 TTFT(Time To First Token)와 TPOT(Time Per Output Token)를 독립적으로 최적화할 수 있다.
  3. 독립 스케일링: Prefill과 Decode 서버의 수를 워크로드 특성에 맞게 독립적으로 조절할 수 있다.

페이지 기반 KV 캐시 전송

KV 캐시를 페이지 단위로 관리하여 전송 효율을 높인다.

def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
    if page_size == 1:
        return kv_indices
    return kv_indices[::page_size] // page_size

페이지 크기가 1이 아닌 경우, 토큰 인덱스를 페이지 인덱스로 변환하여 전송 단위를 줄인다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글