[SGLang] MoE 라우팅: 토큰에서 전문가로의 배분 알고리즘
들어가며
MoE 모델에서 라우팅은 각 토큰을 어떤 전문가에 배분할지 결정하는 핵심 알고리즘이다. SGLang은 softmax Top-K, grouped Top-K(DeepSeek 방식), sigmoid 기반 biased Top-K, 그리고 Triton 커널로 구현된 fused router까지 다양한 라우팅 전략을 지원한다.
관련 소스 경로:
구조도
router_logits (m, num_experts)
│
├─► softmax Top-K ────────► topk_weights, topk_ids
│
├─► grouped Top-K ────────► 그룹 선택 → 그룹 내 Top-K
│ (DeepSeek V2/V3)
│
├─► biased grouped Top-K ─► sigmoid + correction_bias
│ (DeepSeek V3)
│
├─► Fused Router ─────────► 라우팅 + MatMul을 하나의 커널로
│ (Triton CUDACore / TensorCore)
│
└─► Triton Kernels routing ► RoutingData, GatherIndx, ScatterIndx
핵심 코드 분석
1. TopK 클래스: 다중 플랫폼 지원
TopK는 MultiPlatformOp을 상속하여 CUDA, CPU, NPU에서 각각 최적화된 경로를 제공한다.
class TopK(MultiPlatformOp):
def __init__(self, top_k, *, use_grouped_topk=False,
topk_group=None, num_expert_group=None,
renormalize=True, scoring_func="softmax", ...):
self.topk_config = TopKConfig(
top_k=top_k, use_grouped_topk=use_grouped_topk,
renormalize=renormalize, scoring_func=scoring_func, ...
)
TopKConfig는 라우팅에 필요한 모든 파라미터를 담는 데이터클래스다.
2. 기본 softmax Top-K
가장 기본적인 라우팅. sgl_kernel의 topk_softmax CUDA 커널을 사용한다.
def fused_topk(hidden_states, gating_output, topk, renormalize, ...):
topk_weights = torch.empty(M, topk, dtype=torch.float32, device=device)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=device)
if scoring_func == "softmax":
topk_softmax(topk_weights, topk_ids, gating_output, renormalize)
elif scoring_func == "sigmoid":
topk_sigmoid(topk_weights, topk_ids, gating_output,
renormalize, correction_bias)
topk_softmax는 softmax 계산과 Top-K 선택을 단일 커널에서 수행한다. renormalize=True이면 선택된 K개 전문가의 가중치 합이 1이 되도록 정규화한다.
3. Grouped Top-K: DeepSeek 방식
DeepSeek V2/V3에서 사용하는 그룹 기반 라우팅. 전문가를 그룹으로 나누고, 먼저 상위 그룹을 선택한 후 그 안에서 Top-K를 수행한다.
def grouped_topk_gpu(hidden_states, gating_output, topk, renormalize,
num_expert_group=None, topk_group=None, ...):
scores = torch.softmax(gating_output, dim=-1)
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
예를 들어 256개 전문가가 8개 그룹(각 32개)으로 나뉘고, topk_group=4이면 먼저 상위 4개 그룹을 선택한다. 그 후 선택된 그룹 내에서만 Top-K를 수행한다.
score_mask = group_mask.unsqueeze(-1).expand(...).reshape(num_token, -1)
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, ...)
4. Biased Grouped Top-K: correction_bias 적용
DeepSeek V3는 sigmoid 스코어링과 correction bias를 함께 사용한다.
def biased_grouped_topk_impl(hidden_states, gating_output,
correction_bias, topk, renormalize, ...):
scores = gating_output.sigmoid()
scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
group_scores = (
scores_for_choice.view(num_token, num_expert_group, -1)
.topk(2, dim=-1)[0].sum(dim=-1)
)
correction_bias는 학습된 보정치로, 특정 전문가의 선택 확률을 조정한다. 그룹 스코어 계산 시 각 그룹의 상위 2개 값의 합을 사용하는 것이 특징이다.
5. Fused MoE Router: 라우팅 MatMul 융합
router.py의 FusedMoeRouter는 라우터의 linear projection과 Top-K 선택을 하나의 Triton 커널로 융합한다.
@triton.jit
def fused_moe_router_cudacore_kernel(
input_ptr, moe_router_weight_ptr,
topk_weights_ptr, topk_ids_ptr, ...):
# MatMul: input × router_weight
logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
# Softcapping
if moe_softcapping != 0:
logits_scaled = logits / moe_softcapping
logits_softcapped = tl.tanh(logits_scaled) * moe_softcapping
# Top-1
top1 = tl.argmax(logits_softcapped, axis=0)
top1_v = tl.max(logits_softcapped, axis=0)
invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0)
배치 크기에 따라 CUDACore(소규모)와 TensorCore(대규모) 경로로 분기한다.
def fused_moe_router_shim(...):
if (bs >= 512 or num_experts > 8) and hidden_dim % BLOCK_SIZE_K == 0:
return fused_moe_router_tensorcore(...)
else:
return fused_moe_router_cudacore(...)
6. 공유 전문가 슬롯 할당
Grouped Top-K에서 공유 전문가가 있으면, 마지막 Top-K 슬롯을 공유 전문가에 랜덤 할당한다.
if num_fused_shared_experts:
topk_ids[:, -1] = torch.randint(
low=num_experts,
high=num_experts + num_fused_shared_experts,
size=(topk_ids.size(0),), dtype=topk_ids.dtype, device=topk_ids.device,
)
if routed_scaling_factor is not None:
topk_weights[:, -1] = (
topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
)
라우팅 전략 비교
| 전략 | 모델 예시 | 특징 |
|---|---|---|
| softmax Top-K | Mixtral | 가장 기본, 부드러운 확률 분포 |
| grouped Top-K | DeepSeek V2 | 그룹 간 밸런싱, 전문가 특화 촉진 |
| biased sigmoid | DeepSeek V3 | correction bias로 부하 조정 |
| Fused Router | Gemma2 등 | 라우터 MatMul + Top-K 융합 |
출력 포맷
class TopKOutputFormat(IntEnum):
STANDARD = auto() # topk_weights, topk_ids, router_logits
TRITON_KERNEL = auto() # RoutingData, GatherIndx, ScatterIndx
BYPASSED = auto() # hidden_states + config (FlashInfer용)
BYPASSED 포맷은 FlashInfer TRT-LLM처럼 라우팅과 전문가 연산을 하나의 커널로 처리하는 경우에 사용된다.
관련 포스트
- SGLang Fused MoE (Triton) - MoE 전체 레이어 구현
- SGLang EP-MoE - 분산 환경에서의 라우팅
- SGLang EPLB - 부하 균형 알고리즘
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Expert Parallel MoE: 분산 전문가 레이어 구현
- 현재글 : [SGLang] MoE 라우팅: 토큰에서 전문가로의 배분 알고리즘
- 다음글 [SGLang] Elastic Expert Parallelism: 동적 전문가 스케일링
댓글