[SGLang] Fused MoE (Triton): 라우팅과 전문가 연산의 융합
들어가며
Mixture of Experts(MoE) 모델에서 가장 큰 성능 병목은 라우팅 결과에 따라 토큰을 전문가에 배분하고, 각 전문가의 GEMM을 수행한 후 결과를 합산하는 과정이다. SGLang의 FusedMoE 클래스는 이 전체 파이프라인을 하나의 모듈로 융합하여 커널 호출 오버헤드와 메모리 이동을 최소화한다.
소스 경로: python/sglang/srt/layers/moe/fused_moe_triton/layer.py
구조도
┌─────────────────────────────────────────────┐
│ FusedMoE Module │
├─────────────────────────────────────────────┤
│ weights: w13_weight (gate+up), w2_weight │
│ (down projection) │
├──────────┬──────────────┬───────────────────┤
│ TopK │ Dispatcher │ QuantMethod │
│ (라우팅) │ (토큰 분배) │ (전문가 연산) │
└──────────┴──────────────┴───────────────────┘
│ │ │
▼ ▼ ▼
topk_weights dispatch() apply()
topk_ids ─────────► GEMM1 → SiLU
→ GEMM2
combine() ◄─────────
│
▼
final_hidden_states
핵심 코드 분석
1. FusedMoE 초기화: Expert Parallel 분할
FusedMoE.__init__에서 전문가를 EP(Expert Parallel) 랭크별로 분할한다. 공유 전문가(shared expert)와 라우팅 전문가를 구분하여 로컬 전문가 수를 계산한다.
assert (num_experts - num_shared_slots) % self.moe_ep_size == 0
self._num_global_routed = num_experts - num_shared_slots
self._num_local_routed = self._num_global_routed // self.moe_ep_size
self.num_local_experts = self._num_local_routed + num_fused_shared_experts
DeepSeek V3처럼 공유 전문가를 MoE 레이어에 융합하는 경우, EP 백엔드 종류에 따라 슬롯 수 계산이 달라진다.
2. Dispatcher 생성: 백엔드별 분기
create_moe_dispatcher 함수는 All-to-All 백엔드 설정에 따라 적절한 디스패처를 선택한다.
def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
a2a_backend = get_moe_a2a_backend()
if a2a_backend.is_none():
return StandardDispatcher(moe_runner_config)
elif (a2a_backend.is_deepep() or a2a_backend.is_mooncake()
or a2a_backend.is_mori() or a2a_backend.is_nixl()):
return MaybeTboDeepEPDispatcher(...)
elif a2a_backend.is_flashinfer():
return FlashinferDispatcher(...)
StandardDispatcher는 단일 GPU 환경, MaybeTboDeepEPDispatcher는 DeepEP/Mooncake 기반 분산 환경에서 사용된다.
3. Forward 파이프라인: Dispatch → Core → Combine
forward_impl 메서드가 MoE의 핵심 파이프라인을 실행한다.
def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
dispatch_output = self.dispatcher.dispatch(
hidden_states=hidden_states, topk_output=topk_output
)
combine_input = self.run_moe_core(dispatch_output=dispatch_output)
final_hidden_states = self.dispatcher.combine(combine_input=combine_input)
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
3단계 구조가 명확하다: (1) 토큰을 전문가별로 분배, (2) 전문가 GEMM 실행, (3) 결과를 원래 토큰 순서로 재조합. TP/EP 환경에서는 마지막에 all-reduce를 수행한다.
4. run_moe_core: 양자화 메서드에 위임
실제 GEMM 연산은 quant_method.apply()에 위임된다.
def run_moe_core(self, dispatch_output: DispatchOutput) -> CombineInput:
return self.quant_method.apply(
layer=self,
dispatch_output=dispatch_output,
)
quant_method는 UnquantizedFusedMoEMethod, Fp8MoEMethod, ModelOptNvFp4FusedMoEMethod 등 양자화 설정에 따라 달라진다. 이 위임 패턴 덕분에 FusedMoE 클래스는 양자화 방식에 독립적이다.
5. 가중치 로딩: w13과 w2의 분리
MoE의 가중치는 w13_weight(gate_proj + up_proj 융합)와 w2_weight(down_proj)로 구분된다. TP 샤딩 시 w13은 출력 차원, w2는 입력 차원을 기준으로 분할한다.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
Triton 커널 사용 시에는 가중치가 전치(transpose)되어 저장된다.
if not self.use_presharded_weights:
if not is_bias and self.use_triton_kernels:
loaded_weight = loaded_weight.transpose(-2, -1)
6. CUDA Graph 지원
피스와이즈 CUDA 그래프 모드에서는 별도의 최적화 경로를 사용한다.
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if is_in_piecewise_cuda_graph():
if TopKOutputChecker.format_is_standard(topk_output):
return moe_forward_piecewise_cuda_graph_impl(...)
설계 근거
| 설계 결정 | 이유 |
|---|---|
| Dispatch/Core/Combine 3단계 분리 | 분산 통신(Dispatch/Combine)과 연산(Core)을 독립적으로 최적화 |
| quant_method 위임 | FP16/FP8/FP4 등 양자화별 커널을 플러그인 형태로 교체 |
| w13 융합 | gate_proj와 up_proj를 하나의 GEMM으로 실행하여 커널 호출 절반 감소 |
| Triton 전치 저장 | Triton 커널의 메모리 접근 패턴에 최적화 |
관련 포스트
- SGLang CUTLASS MoE - CUTLASS 기반 전문가 연산
- SGLang MoE 라우팅 - Top-K 게이트 선택 알고리즘
- SGLang EP-MoE - 분산 전문가 병렬화
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] 하드웨어별 양자화 튜닝: B200, H100, MI300X 최적 설정
- 현재글 : [SGLang] Fused MoE (Triton): 라우팅과 전문가 연산의 융합
- 다음글 [SGLang] CUTLASS MoE: 최적화 GEMM 커널 기반 전문가 연산
댓글