[SGLang] CUTLASS MoE: 최적화 GEMM 커널 기반 전문가 연산
들어가며
SGLang의 CUTLASS MoE는 NVIDIA CUTLASS 라이브러리를 기반으로 FP8/FP4 양자화된 MoE GEMM을 수행한다. Triton 기반 구현과 달리 CUTLASS는 NVIDIA GPU의 하드웨어 특성에 맞춰 수작업으로 튜닝된 커널을 제공하며, SM90(Hopper)과 SM100(Blackwell)에서 Expert Specialization(ES) 커널을 지원한다.
소스 경로: python/sglang/srt/layers/moe/cutlass_moe.py
구조도
입력: a (m, k) ──► prepare_moe_input ──► 토큰-전문가 매핑
│
┌─────────────────┘
▼
FP8 양자화 (per_token_group_quant)
│
▼
┌─── GEMM1: a_q × w1_q ───┐
│ (m*topk, k) × (E,k,2n) │
└──────────┬───────────────┘
▼
SiLU & Mul
│
▼
┌─── GEMM2: inter × w2_q ──┐
│ (m*topk, n) × (E,n,k) │
└──────────┬────────────────┘
▼
apply_shuffle_mul_sum
│
▼
출력: (m, k)
핵심 코드 분석
1. prepare_moe_input: 토큰-전문가 매핑 준비
cutlass_fused_experts_fp8 함수의 첫 단계에서 prepare_moe_input을 호출하여 토큰을 전문가별로 재배열하기 위한 매핑을 생성한다.
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
prepare_moe_input(
topk_ids, expert_offsets, problem_sizes1, problem_sizes2,
a_map, c_map, num_experts, n, k, blockscale_offsets,
)
a_map은 입력 토큰의 재배열 인덱스, c_map은 출력을 원래 순서로 복원하는 인덱스다. expert_offsets는 각 전문가에 할당된 토큰의 시작 위치를 나타낸다.
2. FP8 양자화와 행 셔플
입력 활성화를 FP8로 양자화한 후, shuffle_rows로 전문가별 순서에 맞게 재배열한다.
a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
그룹 크기 128은 FP8 블록 스케일링의 표준 설정이다. 각 128개 원소 블록마다 독립적인 스케일 팩터를 가진다.
3. Expert Specialization(ES) 커널 분기
SM90 이상에서는 Expert Specialization 커널을 사용할 수 있다. ES 커널은 전문가별로 워프를 특화시켜 로드 밸런싱을 개선한다.
if is_sm90_supported() and es_up:
es_fp8_blockwise_scaled_grouped_mm(
c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
a1_strides, a1_strides, c1_strides,
problem_sizes1, expert_offsets[:-1], workspace,
)
elif use_mxfp8 and es_up:
es_sm100_mxfp8_blockscaled_grouped_mm(...)
else:
fp8_blockwise_scaled_grouped_mm(...)
세 가지 경로가 존재한다: (1) SM90 ES 커널, (2) SM100 MXFP8 ES 커널, (3) 범용 그룹드 GEMM.
4. SiLU 활성화와 두 번째 GEMM
첫 번째 GEMM의 출력은 gate+up의 융합이므로 silu_and_mul로 활성화를 적용한다.
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
silu_and_mul(c1, intermediate)
c1의 shape은 (m*topk, 2n)이다. silu_and_mul은 전반부에 SiLU를 적용하고 후반부와 element-wise 곱을 수행하여 (m*topk, n)을 출력한다.
5. 결과 합산: apply_shuffle_mul_sum
최종 단계에서 전문가 출력을 원래 토큰 순서로 복원하고, 라우팅 가중치를 곱하여 합산한다.
apply_shuffle_mul_sum(c2, output, c_map, topk_weights.to(out_dtype))
이 단일 커널 호출이 셔플, 가중치 곱셈, 합산을 모두 수행한다.
6. FP4 지원: cutlass_moe_fp4
FP4(E2M1) 양자화를 위한 별도 경로도 제공한다.
rep_a_fp4, rep_a_blockscale = scaled_fp4_experts_quant(
a, a1_gscale, params.expert_offsets, params.blockscale_offsets,
num_topk, expert_map=a_map,
)
c1 = cutlass_fp4_group_mm(
rep_a_fp4, w1_fp4, rep_a_blockscale, w1_blockscale,
w1_alphas, out_dtype, params.to_gemm1_args(),
)
FP4는 블록 크기 16으로 스케일링하며, FLOAT4_E2M1_MAX = 6.0의 제한된 범위를 alpha 스케일로 보정한다.
Triton 대비 비교
| 항목 | CUTLASS MoE | Triton MoE |
|---|---|---|
| 커널 작성 방식 | C++ 템플릿, 사전 컴파일 | Python DSL, JIT 컴파일 |
| SM90/SM100 최적화 | ES 커널 네이티브 지원 | 범용 타일링 |
| FP4 지원 | CUTLASS FP4 Group MM | 미지원 |
| MXFP8 | SM100 전용 경로 | 미지원 |
| 유연성 | 하드웨어별 튜닝 필요 | 자동 튜닝 |
설계 근거
CUTLASS 경로는 최신 NVIDIA 하드웨어(Hopper, Blackwell)에서 최대 성능을 추출하기 위해 존재한다. Expert Specialization은 전문가 간 토큰 수 불균형 시 유휴 SM을 줄여주고, MXFP8은 Blackwell의 네이티브 MX 포맷을 활용한다. 범용성보다 성능을 우선하는 선택이다.
관련 포스트
- SGLang Fused MoE (Triton) - Triton 기반 MoE 레이어
- SGLang FlashInfer MoE - FlashInfer/TRT-LLM 하이브리드 커널
- SGLang MoE 라우팅 - Top-K 게이트 선택
참고
관련 포스트
- [vllm] vLLM, MXFP4 양자화 MoE 모델을 위한 CUTLASS 기반 SM100 커널 추가로 성능 향상
- [triton] NVIDIA canSkipBarSync 복원으로 MoE 커널 18GBps 성능 향상
- [논문리뷰] On the Scaling of PEFT: Towards Million Personal Models of Trillion Parameters
- [논문리뷰] NITP: Next Implicit Token Prediction for LLM Pre-training
- [논문리뷰] Confidence-Adaptive SwiGLU for Mixture-of-Experts
SGLang 의 다른글
- 이전글 [SGLang] Fused MoE (Triton): 라우팅과 전문가 연산의 융합
- 현재글 : [SGLang] CUTLASS MoE: 최적화 GEMM 커널 기반 전문가 연산
- 다음글 [SGLang] Expert Parallel MoE: 분산 전문가 레이어 구현
댓글