[SGLang] Disaggregation 커넥터: Mooncake, NIXL, MORI 전송 엔진
들어가며
SGLang의 Disaggregated Serving은 KV 캐시를 Prefill → Decode 서버로 전송하기 위해 플러거블한 커넥터 아키텍처를 제공한다. 현재 Mooncake, NIXL, MORI 세 가지 전송 엔진을 지원하며, 각각의 RDMA/GPU Direct 구현 위에서 동일한 인터페이스를 제공한다.
구조도
┌────────────────────────────────────────────────────┐
│ 공통 인터페이스 │
│ BaseKVManager / BaseKVSender / BaseKVReceiver │
│ (python/sglang/srt/disaggregation/base/conn.py) │
└────────────────────┬───────────────────────────────┘
│ 상속
┌───────────────┼───────────────┐
▼ ▼ ▼
┌──────────┐ ┌──────────┐ ┌──────────┐
│ Mooncake │ │ NIXL │ │ MORI │
│ conn.py │ │ conn.py │ │ conn.py │
├──────────┤ ├──────────┤ ├──────────┤
│ ZMQ + │ │ ZMQ + │ │ ZMQ + │
│ Transfer │ │ NIXL │ │ MORI │
│ Engine │ │ Agent │ │ IOEngine │
└──────────┘ └──────────┘ └──────────┘
핵심 코드 분석
공통 인터페이스: Base 클래스
python/sglang/srt/disaggregation/base/conn.py에서 모든 커넥터가 구현해야 하는 추상 인터페이스를 정의한다.
class BaseKVSender(ABC):
@abstractmethod
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
"""수신 서버에 KV 인덱스 수와 aux 인덱스를 알린다."""
...
@abstractmethod
def send(self, kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None):
"""주어진 인덱스의 KV 캐시를 Decode 서버로 전송한다."""
...
@abstractmethod
def poll(self) -> KVPoll:
"""전송 상태를 확인한다."""
...
class BaseKVReceiver(ABC):
@abstractmethod
def init(self, prefill_dp_rank: int):
"""부트스트랩 메타데이터를 해석하고 수신 준비를 한다."""
...
@abstractmethod
def send_metadata(self, kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None, ...):
"""Prefill 서버에 수신 준비된 KV 인덱스를 알린다."""
...
Mooncake 커넥터
Mooncake는 Moonshot AI에서 개발한 Transfer Engine을 사용한다. ZMQ 기반 부트스트랩과 RDMA 전송을 조합한다.
# python/sglang/srt/disaggregation/mooncake/conn.py
@dataclasses.dataclass
class TransferInfo:
room: int
endpoint: str
dst_port: int
mooncake_session_id: str
dst_kv_indices: npt.NDArray[np.int32]
dst_aux_index: int
dst_state_indices: List[int]
required_dst_info_num: int
is_dummy: bool
staging: Optional[StagingTransferInfo] = None
Mooncake의 TransferInfo는 세션 ID 기반으로 원격 GPU 메모리에 직접 쓰는 RDMA 연결을 관리한다. staging 필드는 heterogeneous TP 환경에서 Staging Buffer를 통한 전송을 지원한다.
@dataclasses.dataclass
class KVArgsRegisterInfo:
room: str
endpoint: str
dst_port: int
mooncake_session_id: str
dst_kv_ptrs: list[int] # 원격 KV 캐시 포인터
dst_aux_ptrs: list[int] # 원격 메타데이터 포인터
dst_state_data_ptrs: list[int] # 원격 상태 데이터 포인터
dst_tp_rank: int
dst_attn_tp_size: int
dst_kv_item_len: int
Decode 서버가 자신의 GPU 메모리 포인터를 Prefill 서버에 등록하여, RDMA one-sided write가 가능하도록 한다.
NIXL 커넥터
NVIDIA의 NIXL(NVIDIA Inference Transfer Library)을 사용하는 커넥터이다.
# python/sglang/srt/disaggregation/nixl/conn.py
@dataclasses.dataclass
class TransferInfo:
room: int
endpoint: str
dst_port: int
agent_name: str # NIXL 에이전트 식별자
dst_kv_indices: npt.NDArray[np.int32]
dst_aux_index: int
required_dst_info_num: int
dst_state_indices: List[int]
def is_dummy(self):
return self.dst_kv_indices.size == 0
NIXL은 agent_name과 agent_metadata(직렬화된 바이너리)를 사용하여 NIXL 에이전트 간 연결을 수립한다.
@dataclasses.dataclass
class TransferStatus:
received_kvs_per_pp: Dict[int, Set[int]] = dataclasses.field(
default_factory=lambda: defaultdict(set))
expected_kvs_per_pp: Dict[int, int] = dataclasses.field(default_factory=dict)
num_pp_ranks_expected: Optional[int] = None
received_aux: bool = False
received_state_per_pp: Set[int] = dataclasses.field(default_factory=set)
expects_state: bool = False
NIXL의 TransferStatus는 PP(Pipeline Parallel) rank별로 KV 청크 수신을 개별 추적하여, 멀티 PP 환경에서의 정확한 완료 감지를 지원한다.
MORI 커넥터
MORI는 RDMA 기반 IOEngine을 사용하는 고성능 전송 엔진이다.
# python/sglang/srt/disaggregation/mori/conn.py
from mori.io import IOEngine, IOEngineConfig, MemoryDesc, MemoryLocationType, EngineDesc
@dataclasses.dataclass
class TransferInfo:
room: int
endpoint: str
dst_port: int
engine_key: str # MORI IOEngine 식별자
dst_kv_indices: npt.NDArray[np.int32]
dst_aux_index: int
required_dst_info_num: int
is_dummy: bool
MORI는 MemoryDesc(메모리 디스크립터)를 사용하여 원격 메모리 영역을 기술한다. msgspec으로 직렬화하여 전송한다.
@dataclasses.dataclass
class KVArgsRegisterInfo:
endpoint: str
dst_port: int
engine_desc: EngineDesc # MORI 엔진 디스크립터
dst_kv_mem_descs: List[MemoryDesc] # KV 메모리 디스크립터
dst_aux_mem_descs: List[MemoryDesc] # 메타데이터 메모리 디스크립터
dst_state_mem_descs: List[MemoryDesc] # 상태 메모리 디스크립터
gpu_id: int
decode_tp_size: int
decode_tp_rank: int
dst_kv_item_len: int
커넥터 비교
┌──────────┬──────────────────┬─────────────────┬──────────────────┐
│ │ Mooncake │ NIXL │ MORI │
├──────────┼──────────────────┼─────────────────┼──────────────────┤
│ 전송방식 │ RDMA (custom) │ NIXL Agent │ RDMA (IOEngine) │
│ 세션관리 │ session_id 기반 │ agent_name 기반 │ engine_key 기반 │
│ 메모리기술│ raw ptr + lens │ agent_metadata │ MemoryDesc │
│ PP 지원 │ required_dst_num │ per-PP tracking │ required_dst_num │
│ Staging │ 지원 (optional) │ 미지원 │ 미지원 │
│ 직렬화 │ struct.pack │ struct.pack │ msgspec.msgpack │
└──────────┴──────────────────┴─────────────────┴──────────────────┘
공통 패턴: ZMQ 부트스트랩
세 커넥터 모두 ZMQ(ZeroMQ)를 사용하여 부트스트랩 핸드셰이크를 수행한다. 부트스트랩 서버가 Prefill과 Decode 서버 간의 초기 연결을 중개하고, 이후 실제 데이터 전송은 각 커넥터의 네이티브 RDMA 채널로 수행된다.
CP(Context Parallel) 지원
모든 커넥터에서 CP rank에 따른 KV 인덱스 필터링이 공통 유틸리티로 제공된다.
def filter_kv_indices_for_cp_rank(kv_mgr, kv_indices, index_slice):
rank_page_indices = page_indices_to_cp_rank_page_indices(
page_indices=kv_indices, total_pages=total_pages,
cp_rank=kv_mgr.attn_cp_rank, cp_size=kv_mgr.attn_cp_size,
)
관련 포스트
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] KV Cache Offloading: Decode 중 메모리 오프로딩
- 현재글 : [SGLang] Disaggregation 커넥터: Mooncake, NIXL, MORI 전송 엔진
- 다음글 [SGLang] Staging Buffer: KV 캐시 전송 버퍼 관리
댓글