본문으로 건너뛰기

[SGLang] CUTLASS MoE: 최적화 GEMM 커널 기반 전문가 연산

들어가며

SGLang의 CUTLASS MoE는 NVIDIA CUTLASS 라이브러리를 기반으로 FP8/FP4 양자화된 MoE GEMM을 수행한다. Triton 기반 구현과 달리 CUTLASS는 NVIDIA GPU의 하드웨어 특성에 맞춰 수작업으로 튜닝된 커널을 제공하며, SM90(Hopper)과 SM100(Blackwell)에서 Expert Specialization(ES) 커널을 지원한다.

소스 경로: python/sglang/srt/layers/moe/cutlass_moe.py

구조도

입력: a (m, k)  ──► prepare_moe_input ──► 토큰-전문가 매핑
                                           
                         ┌─────────────────┘
                         
              FP8 양자화 (per_token_group_quant)
                         
                         
              ┌─── GEMM1: a_q × w1_q ───┐
                 (m*topk, k) × (E,k,2n) 
              └──────────┬───────────────┘
                         
                   SiLU & Mul
                         
                         
              ┌─── GEMM2: inter × w2_q ──┐
                 (m*topk, n) × (E,n,k)   
              └──────────┬────────────────┘
                         
              apply_shuffle_mul_sum
                         
                         
              출력: (m, k)

핵심 코드 분석

1. prepare_moe_input: 토큰-전문가 매핑 준비

cutlass_fused_experts_fp8 함수의 첫 단계에서 prepare_moe_input을 호출하여 토큰을 전문가별로 재배열하기 위한 매핑을 생성한다.

a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

prepare_moe_input(
    topk_ids, expert_offsets, problem_sizes1, problem_sizes2,
    a_map, c_map, num_experts, n, k, blockscale_offsets,
)

a_map은 입력 토큰의 재배열 인덱스, c_map은 출력을 원래 순서로 복원하는 인덱스다. expert_offsets는 각 전문가에 할당된 토큰의 시작 위치를 나타낸다.

2. FP8 양자화와 행 셔플

입력 활성화를 FP8로 양자화한 후, shuffle_rows로 전문가별 순서에 맞게 재배열한다.

a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))

그룹 크기 128은 FP8 블록 스케일링의 표준 설정이다. 각 128개 원소 블록마다 독립적인 스케일 팩터를 가진다.

3. Expert Specialization(ES) 커널 분기

SM90 이상에서는 Expert Specialization 커널을 사용할 수 있다. ES 커널은 전문가별로 워프를 특화시켜 로드 밸런싱을 개선한다.

if is_sm90_supported() and es_up:
    es_fp8_blockwise_scaled_grouped_mm(
        c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
        a1_strides, a1_strides, c1_strides,
        problem_sizes1, expert_offsets[:-1], workspace,
    )
elif use_mxfp8 and es_up:
    es_sm100_mxfp8_blockscaled_grouped_mm(...)
else:
    fp8_blockwise_scaled_grouped_mm(...)

세 가지 경로가 존재한다: (1) SM90 ES 커널, (2) SM100 MXFP8 ES 커널, (3) 범용 그룹드 GEMM.

4. SiLU 활성화와 두 번째 GEMM

첫 번째 GEMM의 출력은 gate+up의 융합이므로 silu_and_mul로 활성화를 적용한다.

intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
silu_and_mul(c1, intermediate)

c1의 shape은 (m*topk, 2n)이다. silu_and_mul은 전반부에 SiLU를 적용하고 후반부와 element-wise 곱을 수행하여 (m*topk, n)을 출력한다.

5. 결과 합산: apply_shuffle_mul_sum

최종 단계에서 전문가 출력을 원래 토큰 순서로 복원하고, 라우팅 가중치를 곱하여 합산한다.

apply_shuffle_mul_sum(c2, output, c_map, topk_weights.to(out_dtype))

이 단일 커널 호출이 셔플, 가중치 곱셈, 합산을 모두 수행한다.

6. FP4 지원: cutlass_moe_fp4

FP4(E2M1) 양자화를 위한 별도 경로도 제공한다.

rep_a_fp4, rep_a_blockscale = scaled_fp4_experts_quant(
    a, a1_gscale, params.expert_offsets, params.blockscale_offsets,
    num_topk, expert_map=a_map,
)
c1 = cutlass_fp4_group_mm(
    rep_a_fp4, w1_fp4, rep_a_blockscale, w1_blockscale,
    w1_alphas, out_dtype, params.to_gemm1_args(),
)

FP4는 블록 크기 16으로 스케일링하며, FLOAT4_E2M1_MAX = 6.0의 제한된 범위를 alpha 스케일로 보정한다.

Triton 대비 비교

항목 CUTLASS MoE Triton MoE
커널 작성 방식 C++ 템플릿, 사전 컴파일 Python DSL, JIT 컴파일
SM90/SM100 최적화 ES 커널 네이티브 지원 범용 타일링
FP4 지원 CUTLASS FP4 Group MM 미지원
MXFP8 SM100 전용 경로 미지원
유연성 하드웨어별 튜닝 필요 자동 튜닝

설계 근거

CUTLASS 경로는 최신 NVIDIA 하드웨어(Hopper, Blackwell)에서 최대 성능을 추출하기 위해 존재한다. Expert Specialization은 전문가 간 토큰 수 불균형 시 유휴 SM을 줄여주고, MXFP8은 Blackwell의 네이티브 MX 포맷을 활용한다. 범용성보다 성능을 우선하는 선택이다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글