본문으로 건너뛰기

[SGLang] Fused MoE (Triton): 라우팅과 전문가 연산의 융합

들어가며

Mixture of Experts(MoE) 모델에서 가장 큰 성능 병목은 라우팅 결과에 따라 토큰을 전문가에 배분하고, 각 전문가의 GEMM을 수행한 후 결과를 합산하는 과정이다. SGLang의 FusedMoE 클래스는 이 전체 파이프라인을 하나의 모듈로 융합하여 커널 호출 오버헤드와 메모리 이동을 최소화한다.

소스 경로: python/sglang/srt/layers/moe/fused_moe_triton/layer.py

구조도

┌─────────────────────────────────────────────┐
│                FusedMoE Module               │
├─────────────────────────────────────────────┤
│  weights: w13_weight (gate+up), w2_weight   │
│           (down projection)                  │
├──────────┬──────────────┬───────────────────┤
│ TopK     │  Dispatcher  │  QuantMethod      │
│ (라우팅) │  (토큰 분배)  │  (전문가 연산)     │
└──────────┴──────────────┴───────────────────┘
         │          │              │
         ▼          ▼              ▼
   topk_weights  dispatch()    apply()
   topk_ids      ─────────►  GEMM1 → SiLU
                              → GEMM2
                  combine()  ◄─────────
                     │
                     ▼
              final_hidden_states

핵심 코드 분석

1. FusedMoE 초기화: Expert Parallel 분할

FusedMoE.__init__에서 전문가를 EP(Expert Parallel) 랭크별로 분할한다. 공유 전문가(shared expert)와 라우팅 전문가를 구분하여 로컬 전문가 수를 계산한다.

assert (num_experts - num_shared_slots) % self.moe_ep_size == 0
self._num_global_routed = num_experts - num_shared_slots
self._num_local_routed = self._num_global_routed // self.moe_ep_size
self.num_local_experts = self._num_local_routed + num_fused_shared_experts

DeepSeek V3처럼 공유 전문가를 MoE 레이어에 융합하는 경우, EP 백엔드 종류에 따라 슬롯 수 계산이 달라진다.

2. Dispatcher 생성: 백엔드별 분기

create_moe_dispatcher 함수는 All-to-All 백엔드 설정에 따라 적절한 디스패처를 선택한다.

def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
    a2a_backend = get_moe_a2a_backend()
    if a2a_backend.is_none():
        return StandardDispatcher(moe_runner_config)
    elif (a2a_backend.is_deepep() or a2a_backend.is_mooncake()
          or a2a_backend.is_mori() or a2a_backend.is_nixl()):
        return MaybeTboDeepEPDispatcher(...)
    elif a2a_backend.is_flashinfer():
        return FlashinferDispatcher(...)

StandardDispatcher는 단일 GPU 환경, MaybeTboDeepEPDispatcher는 DeepEP/Mooncake 기반 분산 환경에서 사용된다.

3. Forward 파이프라인: Dispatch → Core → Combine

forward_impl 메서드가 MoE의 핵심 파이프라인을 실행한다.

def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
    dispatch_output = self.dispatcher.dispatch(
        hidden_states=hidden_states, topk_output=topk_output
    )
    combine_input = self.run_moe_core(dispatch_output=dispatch_output)
    final_hidden_states = self.dispatcher.combine(combine_input=combine_input)

    if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
    return final_hidden_states

3단계 구조가 명확하다: (1) 토큰을 전문가별로 분배, (2) 전문가 GEMM 실행, (3) 결과를 원래 토큰 순서로 재조합. TP/EP 환경에서는 마지막에 all-reduce를 수행한다.

4. run_moe_core: 양자화 메서드에 위임

실제 GEMM 연산은 quant_method.apply()에 위임된다.

def run_moe_core(self, dispatch_output: DispatchOutput) -> CombineInput:
    return self.quant_method.apply(
        layer=self,
        dispatch_output=dispatch_output,
    )

quant_methodUnquantizedFusedMoEMethod, Fp8MoEMethod, ModelOptNvFp4FusedMoEMethod 등 양자화 설정에 따라 달라진다. 이 위임 패턴 덕분에 FusedMoE 클래스는 양자화 방식에 독립적이다.

5. 가중치 로딩: w13과 w2의 분리

MoE의 가중치는 w13_weight(gate_proj + up_proj 융합)와 w2_weight(down_proj)로 구분된다. TP 샤딩 시 w13은 출력 차원, w2는 입력 차원을 기준으로 분할한다.

SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

Triton 커널 사용 시에는 가중치가 전치(transpose)되어 저장된다.

if not self.use_presharded_weights:
    if not is_bias and self.use_triton_kernels:
        loaded_weight = loaded_weight.transpose(-2, -1)

6. CUDA Graph 지원

피스와이즈 CUDA 그래프 모드에서는 별도의 최적화 경로를 사용한다.

def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
    if is_in_piecewise_cuda_graph():
        if TopKOutputChecker.format_is_standard(topk_output):
            return moe_forward_piecewise_cuda_graph_impl(...)

설계 근거

설계 결정 이유
Dispatch/Core/Combine 3단계 분리 분산 통신(Dispatch/Combine)과 연산(Core)을 독립적으로 최적화
quant_method 위임 FP16/FP8/FP4 등 양자화별 커널을 플러그인 형태로 교체
w13 융합 gate_proj와 up_proj를 하나의 GEMM으로 실행하여 커널 호출 절반 감소
Triton 전치 저장 Triton 커널의 메모리 접근 패턴에 최적화

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글