본문으로 건너뛰기

[SGLang] Expert Parallel MoE: 분산 전문가 레이어 구현

들어가며

256개의 전문가를 가진 DeepSeek V3 같은 모델에서는 모든 전문가를 단일 GPU에 올릴 수 없다. Expert Parallel(EP)은 전문가를 여러 GPU에 분산 배치하고, All-to-All 통신으로 토큰을 올바른 전문가에 전달하는 병렬화 전략이다. SGLang의 DeepEPMoE 클래스는 DeepEP 라이브러리를 활용한 EP 구현을 제공한다.

소스 경로: python/sglang/srt/layers/moe/ep_moe/layer.py

구조도

GPU 0 (EP rank 0)              GPU 1 (EP rank 1)
┌──────────────────┐          ┌──────────────────┐
│ Expert 0, 1, ... │          │ Expert N/2, ...   │
│     (로컬)       │          │     (로컬)        │
└────────┬─────────┘          └────────┬──────────┘
         │                             │
         ▼                             ▼
   ┌──────────────────────────────────────┐
   │         All-to-All Dispatch          │
   │  (토큰 → 해당 전문가가 있는 GPU로)    │
   └──────────────────────────────────────┘
         │                             │
         ▼                             ▼
   Expert GEMM (로컬)           Expert GEMM (로컬)
         │                             │
         ▼                             ▼
   ┌──────────────────────────────────────┐
   │         All-to-All Combine           │
   │  (결과 → 원래 토큰이 있던 GPU로)      │
   └──────────────────────────────────────┘

핵심 코드 분석

1. DeepEPMoE: FusedMoE 상속 구조

DeepEPMoEFusedMoE를 상속하며, 양자화 설정에 따라 추가 초기화를 수행한다.

class DeepEPMoE(FusedMoE):
    def __init__(self, num_experts, top_k, hidden_size, intermediate_size, ...):
        super().__init__(...)
        if isinstance(quant_config, Fp8Config):
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.use_fp8_w8a8 = True
            self.fp8_dtype = torch.float8_e4m3fn

FP8, W4AFp8 등 다양한 양자화 구성을 지원하며, DeepEP의 Normal/Low-Latency 두 가지 모드를 사용한다.

2. Forward: Dispatch → Core → Combine 파이프라인

forward_impl은 3단계 파이프라인을 실행한다.

def forward_impl(self, hidden_states, topk_output):
    dispatch_output = self.dispatcher.dispatch(
        hidden_states=hidden_states, topk_output=topk_output
    )
    combine_input = self.run_moe_core(dispatch_output)
    hidden_states = self.dispatcher.combine(combine_input=combine_input)
    return hidden_states

dispatcher는 DeepEP 기반 All-to-All 통신을 담당한다. dispatch에서 토큰을 전문가가 있는 GPU로 전송하고, combine에서 결과를 원래 GPU로 가져온다.

3. run_moe_core: 백엔드별 분기

run_moe_core는 하드웨어 플랫폼과 디스패치 포맷에 따라 실행 경로를 선택한다.

def run_moe_core(self, dispatch_output):
    if _use_aiter:
        output = self.forward_aiter(dispatch_output)
    elif _is_npu:
        output = self.forward_npu(dispatch_output)
    elif DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
        if self.use_w4afp8:
            output = self.forward_cutlass_w4afp8(dispatch_output)
    elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
        if get_moe_runner_backend().is_flashinfer_cutedsl():
            output = self.forward_flashinfer_cutedsl(dispatch_output)

AMD(AIter), NPU(Ascend), NVIDIA(CUTLASS/FlashInfer) 등 플랫폼별로 최적화된 경로가 존재한다.

4. AMD AIter 경로: 전문가 마스크

AMD GPU에서는 AIter의 fused_moe를 사용한다. DeepEP의 -1 인덱스(무효 전문가)를 AIter가 이해할 수 있는 마스크로 변환한다.

def forward_aiter(self, dispatch_output):
    topk_ids_copy = topk_ids.to(torch.int32)
    topk_ids_copy[topk_ids_copy == -1] = self.num_local_experts

    return fused_moe(
        hidden_states, self.w13_weight, self.w2_weight,
        topk_weights, topk_ids_copy,
        w1_scale=self.w13_weight_scale_inv,
        w2_scale=self.w2_weight_scale_inv,
        expert_mask=self.expert_mask,
    )

expert_mask[1, 1, ..., 1, 0] 형태로, 마지막 인덱스(num_local_experts)를 무효로 처리한다.

5. MoriEPMoE: 대안 EP 백엔드

MoriEPMoE는 AIter 기반의 또 다른 EP 구현이다. 전체 전문가에 대한 마스크를 사용하여 로컬 전문가만 활성화한다.

class MoriEPMoE(DeepEPMoE):
    def __init__(self, ...):
        self.expert_mask = torch.zeros(
            (self.num_experts), device=torch.cuda.current_device(), dtype=torch.int32,
        )
        expert_start_idx = self.moe_ep_rank * self.num_local_experts
        expert_end_idx = expert_start_idx + self.num_local_experts
        self.expert_mask[expert_start_idx:expert_end_idx] = 1

6. 백엔드 선택 로직

get_moe_impl_class 함수가 All-to-All 백엔드에 따라 적절한 MoE 클래스를 반환한다.

def get_moe_impl_class(quant_config):
    if get_moe_a2a_backend().is_mori():
        return MoriEPMoE
    if (get_moe_a2a_backend().is_deepep()
        or get_moe_a2a_backend().is_mooncake()
        or get_moe_a2a_backend().is_nixl()):
        return DeepEPMoE
    return FusedMoE

DeepEP Normal vs Low-Latency 모드

항목 Normal 모드 Low-Latency(LL) 모드
통신 방식 표준 All-to-All RDMA 기반 저지연
토큰 배치 연속 버퍼 마스크 기반
적합 시나리오 프리필(대량 토큰) 디코딩(소량 토큰)
DeepGEMM 필요 선택적 필수(NVIDIA)

설계 근거

DeepEPMoE가 FusedMoE를 상속하는 이유는 가중치 로딩, 양자화 메서드 관리 등의 공통 로직을 재사용하기 위해서다. EP 특화 로직(All-to-All 디스패치, 전문가 마스크)만 오버라이드한다. AMD/NPU/NVIDIA 각각에 대한 분기는 하드웨어별 최적 커널을 사용하기 위한 필연적 복잡성이다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글