[SGLang] FlashInfer + TensorRT-LLM MoE: 하이브리드 MoE 커널
들어가며
SGLang은 Triton과 CUTLASS 외에도 FlashInfer와 TensorRT-LLM(TRT-LLM)의 MoE 커널을 지원한다. FlashInfer의 trtllm_fp8_block_scale_moe는 TRT-LLM의 커널을 FlashInfer 인터페이스로 래핑한 것이고, FlashInfer의 CuteDSL 커널은 NVIDIA의 CuTe DSL을 활용한 FP4 전용 경로다. 이 두 가지는 각각 다른 시나리오에서 최적 성능을 제공한다.
관련 소스 경로:
python/sglang/srt/layers/moe/flashinfer_trtllm_moe.pypython/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
구조도
MoE 커널 백엔드 선택
│
┌───────────────┼────────────────┐
▼ ▼ ▼
FlashInfer FlashInfer FlashInfer
TRT-LLM MoE TRT-LLM CuteDSL MoE
(FP8 Block) Routed MoE (FP4)
│ (사전 라우팅) │
▼ ▼ ▼
routing_logits topk_ids masked_m
+ hidden_states + hidden_states + hidden_states
│ │ │
▼ ▼ ▼
[라우팅+GEMM [GEMM만 실행] [Grouped GEMM
단일 커널] + SiLU + GEMM]
│ │ │
▼ ▼ ▼
output output output
핵심 코드 분석
1. TRT-LLM FP8 Block Scale MoE: 라우팅 통합
trtllm_fp8_block_scale_moe_wrapper는 라우팅 logits를 직접 받아 라우팅과 전문가 GEMM을 하나의 커널에서 수행한다.
@register_custom_op(fake_impl=_fake_fp8_block_scale_moe)
def trtllm_fp8_block_scale_moe_wrapper(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
hidden_states_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
gemm2_weights: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
num_experts: int,
top_k: int,
n_group: Optional[int],
topk_group: Optional[int],
routed_scaling_factor: Optional[float],
routing_method_type: int = 0,
...
) -> torch.Tensor:
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
return trtllm_fp8_block_scale_moe(**kwargs)
핵심 파라미터를 살펴보면:
routing_logits: 라우터의 raw logits (커널 내부에서 softmax/sigmoid 처리)n_group,topk_group: DeepSeek 스타일 grouped Top-K 지원routing_method_type: 라우팅 알고리즘 유형 (softmax=0 등)routed_scaling_factor: 라우팅 가중치 스케일링
2. TRT-LLM Routed MoE: 사전 라우팅 경로
이미 라우팅이 완료된 경우를 위한 별도 래퍼도 제공한다.
@register_custom_op(fake_impl=_fake_fp8_block_scale_routed_moe)
def trtllm_fp8_block_scale_routed_moe_wrapper(
topk_ids: torch.Tensor, # routing_logits 대신 topk_ids
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
hidden_states_scale: torch.Tensor,
...
) -> torch.Tensor:
from flashinfer.fused_moe import trtllm_fp8_block_scale_routed_moe
return trtllm_fp8_block_scale_routed_moe(**kwargs)
routing_logits 대신 topk_ids를 받는다. EP 환경에서 디스패처가 이미 라우팅을 수행한 경우에 사용된다.
3. Per-Tensor Scale MoE: 간소화된 양자화
블록 스케일 대신 텐서 단위 스케일을 사용하는 경로도 있다.
def trtllm_fp8_per_tensor_scale_moe_wrapper(
routing_logits: torch.Tensor,
hidden_states: torch.Tensor,
gemm1_weights: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
gemm2_weights: torch.Tensor,
output2_scales_scalar: torch.Tensor,
use_routing_scales_on_input: bool,
...
):
output1_scales_scalar와 output1_scales_gate_scalar가 gate/up projection의 출력 스케일을 개별적으로 관리한다.
4. CUDA Graph 호환: register_custom_op
모든 래퍼 함수가 @register_custom_op으로 데코레이트되어 있다. 이는 torch.compile 및 CUDA Graph와의 호환성을 위한 것이다.
def _fake_fp8_block_scale_moe(routing_logits, routing_bias,
hidden_states, ...):
return torch.empty(
hidden_states.shape, dtype=torch.bfloat16,
device=hidden_states.device
)
@register_custom_op(fake_impl=_fake_fp8_block_scale_moe)
def trtllm_fp8_block_scale_moe_wrapper(...):
...
fake_impl은 shape 추론용으로, 실제 커널 없이도 컴파일 그래프를 구성할 수 있게 한다.
5. FlashInfer CuteDSL MoE: FP4 마스크 기반
flashinfer_cutedsl_moe_masked는 NVIDIA CuTe DSL로 구현된 FP4 전용 MoE 커널이다. DeepEP Low-Latency 모드의 마스크 기반 입력을 처리한다.
def flashinfer_cutedsl_moe_masked(
hidden_states: tuple[torch.Tensor, Optional[torch.Tensor]],
input_global_scale, w1, w1_blockscale, w1_alpha,
w2, a2_global_scale, w2_blockscale, w2_alpha,
masked_m, ...):
# FP4 양자화 (필요시)
if hidden_states[1] is not None:
a_q = hidden_states[0].view(torch.uint8)
a_q_sf = hidden_states[1].view(torch.float8_e4m3fn)
else:
a_q, a_q_sf = scaled_fp4_grouped_quantize(
hidden_states[0], masked_m, input_global_scale,
)
hidden_states는 튜플로, 이미 양자화된 경우 (fp4_data, scale), 아닌 경우 (bf16_data, None)을 받는다.
6. CuteDSL Grouped GEMM
FP4 GEMM은 grouped_gemm_nt_masked로 수행된다.
# Gemm1: up + gate projection
grouped_gemm_nt_masked(
(a_q, a_q_sf),
(w1.permute(1, 2, 0), w1_blockscale),
gateup_output,
masked_m,
ab_dtype="float4_e2m1fn",
sf_dtype="float8_e4m3fn",
c_dtype="bfloat16",
sf_vec_size=16,
alpha=w1_alpha.view(1, 1, num_experts),
)
# SiLU + FP4 재양자화
diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
gateup_output.permute(2, 0, 1), masked_m, a2_global_scale,
)
# Gemm2: down projection
grouped_gemm_nt_masked(
(diq, diq_sf),
(w2.permute(1, 2, 0), w2_blockscale),
out, masked_m, ...
)
masked_m 텐서는 각 전문가에 할당된 실제 토큰 수를 나타내며, 유효하지 않은 행의 연산을 건너뛴다.
커널별 비교
| 항목 | TRT-LLM MoE | CuteDSL MoE | Triton MoE |
|---|---|---|---|
| 양자화 | FP8 block/per-tensor | FP4 (E2M1) | FP16/BF16 |
| 라우팅 통합 | 지원 (내부 softmax) | 미지원 (외부) | 미지원 |
| 입력 형태 | Dense | 마스크 기반 | Dense |
| EP 모드 | Normal/LL | LL 전용 | Normal |
| HW 요구 | SM80+ | SM90+ | SM80+ |
설계 근거
TRT-LLM 커널이 라우팅을 내부에 포함하는 이유는 커널 호출 오버헤드를 줄이고, 라우팅 결과를 중간 버퍼 없이 즉시 사용하기 위해서다. CuteDSL 커널은 Blackwell의 FP4 하드웨어 지원을 활용하여 메모리 대역폭을 절반으로 줄인다. 두 커널 모두 register_custom_op으로 CUDA Graph에 통합되어 추론 루프의 일부로 캡처된다.
관련 포스트
- SGLang CUTLASS MoE - CUTLASS 기반 전문가 연산
- SGLang Fused MoE (Triton) - Triton 기반 MoE 레이어
- SGLang MoE 라우팅 - 라우팅 알고리즘
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] EPLB: Expert-Parallel Load Balancing 알고리즘
- 현재글 : [SGLang] FlashInfer + TensorRT-LLM MoE: 하이브리드 MoE 커널
- 다음글 [SGLang] Speculative Decoding 개요: 원리와 구현 아키텍처
댓글