본문으로 건너뛰기

[SGLang] FlashInfer + TensorRT-LLM MoE: 하이브리드 MoE 커널

들어가며

SGLang은 Triton과 CUTLASS 외에도 FlashInfer와 TensorRT-LLM(TRT-LLM)의 MoE 커널을 지원한다. FlashInfer의 trtllm_fp8_block_scale_moe는 TRT-LLM의 커널을 FlashInfer 인터페이스로 래핑한 것이고, FlashInfer의 CuteDSL 커널은 NVIDIA의 CuTe DSL을 활용한 FP4 전용 경로다. 이 두 가지는 각각 다른 시나리오에서 최적 성능을 제공한다.

관련 소스 경로:

구조도

                    MoE 커널 백엔드 선택
                          │
          ┌───────────────┼────────────────┐
          ▼               ▼                ▼
   FlashInfer         FlashInfer      FlashInfer
   TRT-LLM MoE       TRT-LLM        CuteDSL MoE
   (FP8 Block)       Routed MoE       (FP4)
        │            (사전 라우팅)         │
        ▼                ▼                ▼
  routing_logits    topk_ids          masked_m
  + hidden_states   + hidden_states   + hidden_states
        │                │                │
        ▼                ▼                ▼
  [라우팅+GEMM       [GEMM만 실행]     [Grouped GEMM
   단일 커널]                           + SiLU + GEMM]
        │                │                │
        ▼                ▼                ▼
     output           output           output

핵심 코드 분석

1. TRT-LLM FP8 Block Scale MoE: 라우팅 통합

trtllm_fp8_block_scale_moe_wrapper는 라우팅 logits를 직접 받아 라우팅과 전문가 GEMM을 하나의 커널에서 수행한다.

@register_custom_op(fake_impl=_fake_fp8_block_scale_moe)
def trtllm_fp8_block_scale_moe_wrapper(
    routing_logits: torch.Tensor,
    routing_bias: Optional[torch.Tensor],
    hidden_states: torch.Tensor,
    hidden_states_scale: torch.Tensor,
    gemm1_weights: torch.Tensor,
    gemm1_weights_scale: torch.Tensor,
    gemm2_weights: torch.Tensor,
    gemm2_weights_scale: torch.Tensor,
    num_experts: int,
    top_k: int,
    n_group: Optional[int],
    topk_group: Optional[int],
    routed_scaling_factor: Optional[float],
    routing_method_type: int = 0,
    ...
) -> torch.Tensor:
    from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
    return trtllm_fp8_block_scale_moe(**kwargs)

핵심 파라미터를 살펴보면:

  • routing_logits: 라우터의 raw logits (커널 내부에서 softmax/sigmoid 처리)
  • n_group, topk_group: DeepSeek 스타일 grouped Top-K 지원
  • routing_method_type: 라우팅 알고리즘 유형 (softmax=0 등)
  • routed_scaling_factor: 라우팅 가중치 스케일링

2. TRT-LLM Routed MoE: 사전 라우팅 경로

이미 라우팅이 완료된 경우를 위한 별도 래퍼도 제공한다.

@register_custom_op(fake_impl=_fake_fp8_block_scale_routed_moe)
def trtllm_fp8_block_scale_routed_moe_wrapper(
    topk_ids: torch.Tensor,        # routing_logits 대신 topk_ids
    routing_bias: Optional[torch.Tensor],
    hidden_states: torch.Tensor,
    hidden_states_scale: torch.Tensor,
    ...
) -> torch.Tensor:
    from flashinfer.fused_moe import trtllm_fp8_block_scale_routed_moe
    return trtllm_fp8_block_scale_routed_moe(**kwargs)

routing_logits 대신 topk_ids를 받는다. EP 환경에서 디스패처가 이미 라우팅을 수행한 경우에 사용된다.

3. Per-Tensor Scale MoE: 간소화된 양자화

블록 스케일 대신 텐서 단위 스케일을 사용하는 경로도 있다.

def trtllm_fp8_per_tensor_scale_moe_wrapper(
    routing_logits: torch.Tensor,
    hidden_states: torch.Tensor,
    gemm1_weights: torch.Tensor,
    output1_scales_scalar: torch.Tensor,
    output1_scales_gate_scalar: torch.Tensor,
    gemm2_weights: torch.Tensor,
    output2_scales_scalar: torch.Tensor,
    use_routing_scales_on_input: bool,
    ...
):

output1_scales_scalaroutput1_scales_gate_scalar가 gate/up projection의 출력 스케일을 개별적으로 관리한다.

4. CUDA Graph 호환: register_custom_op

모든 래퍼 함수가 @register_custom_op으로 데코레이트되어 있다. 이는 torch.compile 및 CUDA Graph와의 호환성을 위한 것이다.

def _fake_fp8_block_scale_moe(routing_logits, routing_bias,
                               hidden_states, ...):
    return torch.empty(
        hidden_states.shape, dtype=torch.bfloat16,
        device=hidden_states.device
    )

@register_custom_op(fake_impl=_fake_fp8_block_scale_moe)
def trtllm_fp8_block_scale_moe_wrapper(...):
    ...

fake_impl은 shape 추론용으로, 실제 커널 없이도 컴파일 그래프를 구성할 수 있게 한다.

5. FlashInfer CuteDSL MoE: FP4 마스크 기반

flashinfer_cutedsl_moe_masked는 NVIDIA CuTe DSL로 구현된 FP4 전용 MoE 커널이다. DeepEP Low-Latency 모드의 마스크 기반 입력을 처리한다.

def flashinfer_cutedsl_moe_masked(
    hidden_states: tuple[torch.Tensor, Optional[torch.Tensor]],
    input_global_scale, w1, w1_blockscale, w1_alpha,
    w2, a2_global_scale, w2_blockscale, w2_alpha,
    masked_m, ...):

    # FP4 양자화 (필요시)
    if hidden_states[1] is not None:
        a_q = hidden_states[0].view(torch.uint8)
        a_q_sf = hidden_states[1].view(torch.float8_e4m3fn)
    else:
        a_q, a_q_sf = scaled_fp4_grouped_quantize(
            hidden_states[0], masked_m, input_global_scale,
        )

hidden_states는 튜플로, 이미 양자화된 경우 (fp4_data, scale), 아닌 경우 (bf16_data, None)을 받는다.

6. CuteDSL Grouped GEMM

FP4 GEMM은 grouped_gemm_nt_masked로 수행된다.

# Gemm1: up + gate projection
grouped_gemm_nt_masked(
    (a_q, a_q_sf),
    (w1.permute(1, 2, 0), w1_blockscale),
    gateup_output,
    masked_m,
    ab_dtype="float4_e2m1fn",
    sf_dtype="float8_e4m3fn",
    c_dtype="bfloat16",
    sf_vec_size=16,
    alpha=w1_alpha.view(1, 1, num_experts),
)

# SiLU + FP4 재양자화
diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
    gateup_output.permute(2, 0, 1), masked_m, a2_global_scale,
)

# Gemm2: down projection
grouped_gemm_nt_masked(
    (diq, diq_sf),
    (w2.permute(1, 2, 0), w2_blockscale),
    out, masked_m, ...
)

masked_m 텐서는 각 전문가에 할당된 실제 토큰 수를 나타내며, 유효하지 않은 행의 연산을 건너뛴다.

커널별 비교

항목 TRT-LLM MoE CuteDSL MoE Triton MoE
양자화 FP8 block/per-tensor FP4 (E2M1) FP16/BF16
라우팅 통합 지원 (내부 softmax) 미지원 (외부) 미지원
입력 형태 Dense 마스크 기반 Dense
EP 모드 Normal/LL LL 전용 Normal
HW 요구 SM80+ SM90+ SM80+

설계 근거

TRT-LLM 커널이 라우팅을 내부에 포함하는 이유는 커널 호출 오버헤드를 줄이고, 라우팅 결과를 중간 버퍼 없이 즉시 사용하기 위해서다. CuteDSL 커널은 Blackwell의 FP4 하드웨어 지원을 활용하여 메모리 대역폭을 절반으로 줄인다. 두 커널 모두 register_custom_op으로 CUDA Graph에 통합되어 추론 루프의 일부로 캡처된다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글