본문으로 건너뛰기

[SGLang] Disaggregation 커넥터: Mooncake, NIXL, MORI 전송 엔진

들어가며

SGLang의 Disaggregated Serving은 KV 캐시를 Prefill → Decode 서버로 전송하기 위해 플러거블한 커넥터 아키텍처를 제공한다. 현재 Mooncake, NIXL, MORI 세 가지 전송 엔진을 지원하며, 각각의 RDMA/GPU Direct 구현 위에서 동일한 인터페이스를 제공한다.

구조도

┌────────────────────────────────────────────────────┐
│                  공통 인터페이스                      │
│  BaseKVManager / BaseKVSender / BaseKVReceiver      │
│  (python/sglang/srt/disaggregation/base/conn.py)   │
└────────────────────┬───────────────────────────────┘
                     │ 상속
     ┌───────────────┼───────────────┐
     ▼               ▼               ▼
┌──────────┐  ┌──────────┐  ┌──────────┐
│ Mooncake │  │   NIXL   │  │   MORI   │
│ conn.py  │  │ conn.py  │  │ conn.py  │
├──────────┤  ├──────────┤  ├──────────┤
│ ZMQ +    │  │ ZMQ +    │  │ ZMQ +    │
│ Transfer │  │ NIXL     │  │ MORI     │
│ Engine   │  │ Agent    │  │ IOEngine │
└──────────┘  └──────────┘  └──────────┘

핵심 코드 분석

공통 인터페이스: Base 클래스

python/sglang/srt/disaggregation/base/conn.py에서 모든 커넥터가 구현해야 하는 추상 인터페이스를 정의한다.

class BaseKVSender(ABC):
    @abstractmethod
    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """수신 서버에 KV 인덱스 수와 aux 인덱스를 알린다."""
        ...

    @abstractmethod
    def send(self, kv_indices: npt.NDArray[np.int32],
             state_indices: Optional[List[int]] = None):
        """주어진 인덱스의 KV 캐시를 Decode 서버로 전송한다."""
        ...

    @abstractmethod
    def poll(self) -> KVPoll:
        """전송 상태를 확인한다."""
        ...
class BaseKVReceiver(ABC):
    @abstractmethod
    def init(self, prefill_dp_rank: int):
        """부트스트랩 메타데이터를 해석하고 수신 준비를 한다."""
        ...

    @abstractmethod
    def send_metadata(self, kv_indices: npt.NDArray[np.int32],
                      aux_index: Optional[int] = None, ...):
        """Prefill 서버에 수신 준비된 KV 인덱스를 알린다."""
        ...

Mooncake 커넥터

Mooncake는 Moonshot AI에서 개발한 Transfer Engine을 사용한다. ZMQ 기반 부트스트랩과 RDMA 전송을 조합한다.

# python/sglang/srt/disaggregation/mooncake/conn.py

@dataclasses.dataclass
class TransferInfo:
    room: int
    endpoint: str
    dst_port: int
    mooncake_session_id: str
    dst_kv_indices: npt.NDArray[np.int32]
    dst_aux_index: int
    dst_state_indices: List[int]
    required_dst_info_num: int
    is_dummy: bool
    staging: Optional[StagingTransferInfo] = None

Mooncake의 TransferInfo는 세션 ID 기반으로 원격 GPU 메모리에 직접 쓰는 RDMA 연결을 관리한다. staging 필드는 heterogeneous TP 환경에서 Staging Buffer를 통한 전송을 지원한다.

@dataclasses.dataclass
class KVArgsRegisterInfo:
    room: str
    endpoint: str
    dst_port: int
    mooncake_session_id: str
    dst_kv_ptrs: list[int]         # 원격 KV 캐시 포인터
    dst_aux_ptrs: list[int]        # 원격 메타데이터 포인터
    dst_state_data_ptrs: list[int] # 원격 상태 데이터 포인터
    dst_tp_rank: int
    dst_attn_tp_size: int
    dst_kv_item_len: int

Decode 서버가 자신의 GPU 메모리 포인터를 Prefill 서버에 등록하여, RDMA one-sided write가 가능하도록 한다.

NIXL 커넥터

NVIDIA의 NIXL(NVIDIA Inference Transfer Library)을 사용하는 커넥터이다.

# python/sglang/srt/disaggregation/nixl/conn.py

@dataclasses.dataclass
class TransferInfo:
    room: int
    endpoint: str
    dst_port: int
    agent_name: str                    # NIXL 에이전트 식별자
    dst_kv_indices: npt.NDArray[np.int32]
    dst_aux_index: int
    required_dst_info_num: int
    dst_state_indices: List[int]

    def is_dummy(self):
        return self.dst_kv_indices.size == 0

NIXL은 agent_nameagent_metadata(직렬화된 바이너리)를 사용하여 NIXL 에이전트 간 연결을 수립한다.

@dataclasses.dataclass
class TransferStatus:
    received_kvs_per_pp: Dict[int, Set[int]] = dataclasses.field(
        default_factory=lambda: defaultdict(set))
    expected_kvs_per_pp: Dict[int, int] = dataclasses.field(default_factory=dict)
    num_pp_ranks_expected: Optional[int] = None
    received_aux: bool = False
    received_state_per_pp: Set[int] = dataclasses.field(default_factory=set)
    expects_state: bool = False

NIXL의 TransferStatus는 PP(Pipeline Parallel) rank별로 KV 청크 수신을 개별 추적하여, 멀티 PP 환경에서의 정확한 완료 감지를 지원한다.

MORI 커넥터

MORI는 RDMA 기반 IOEngine을 사용하는 고성능 전송 엔진이다.

# python/sglang/srt/disaggregation/mori/conn.py
from mori.io import IOEngine, IOEngineConfig, MemoryDesc, MemoryLocationType, EngineDesc

@dataclasses.dataclass
class TransferInfo:
    room: int
    endpoint: str
    dst_port: int
    engine_key: str          # MORI IOEngine 식별자
    dst_kv_indices: npt.NDArray[np.int32]
    dst_aux_index: int
    required_dst_info_num: int
    is_dummy: bool

MORI는 MemoryDesc(메모리 디스크립터)를 사용하여 원격 메모리 영역을 기술한다. msgspec으로 직렬화하여 전송한다.

@dataclasses.dataclass
class KVArgsRegisterInfo:
    endpoint: str
    dst_port: int
    engine_desc: EngineDesc          # MORI 엔진 디스크립터
    dst_kv_mem_descs: List[MemoryDesc]   # KV 메모리 디스크립터
    dst_aux_mem_descs: List[MemoryDesc]  # 메타데이터 메모리 디스크립터
    dst_state_mem_descs: List[MemoryDesc] # 상태 메모리 디스크립터
    gpu_id: int
    decode_tp_size: int
    decode_tp_rank: int
    dst_kv_item_len: int

커넥터 비교

┌──────────┬──────────────────┬─────────────────┬──────────────────┐
│          │    Mooncake      │     NIXL        │     MORI         │
├──────────┼──────────────────┼─────────────────┼──────────────────┤
│ 전송방식  │ RDMA (custom)    │ NIXL Agent      │ RDMA (IOEngine)  │
│ 세션관리  │ session_id 기반   │ agent_name 기반  │ engine_key 기반   │
│ 메모리기술│ raw ptr + lens   │ agent_metadata  │ MemoryDesc       │
│ PP 지원  │ required_dst_num │ per-PP tracking │ required_dst_num │
│ Staging  │ 지원 (optional)  │ 미지원           │ 미지원            │
│ 직렬화   │ struct.pack      │ struct.pack     │ msgspec.msgpack  │
└──────────┴──────────────────┴─────────────────┴──────────────────┘

공통 패턴: ZMQ 부트스트랩

세 커넥터 모두 ZMQ(ZeroMQ)를 사용하여 부트스트랩 핸드셰이크를 수행한다. 부트스트랩 서버가 Prefill과 Decode 서버 간의 초기 연결을 중개하고, 이후 실제 데이터 전송은 각 커넥터의 네이티브 RDMA 채널로 수행된다.

CP(Context Parallel) 지원

모든 커넥터에서 CP rank에 따른 KV 인덱스 필터링이 공통 유틸리티로 제공된다.

def filter_kv_indices_for_cp_rank(kv_mgr, kv_indices, index_slice):
    rank_page_indices = page_indices_to_cp_rank_page_indices(
        page_indices=kv_indices, total_pages=total_pages,
        cp_rank=kv_mgr.attn_cp_rank, cp_size=kv_mgr.attn_cp_size,
    )

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글