본문으로 건너뛰기

[SGLang] MoE 라우팅: 토큰에서 전문가로의 배분 알고리즘

들어가며

MoE 모델에서 라우팅은 각 토큰을 어떤 전문가에 배분할지 결정하는 핵심 알고리즘이다. SGLang은 softmax Top-K, grouped Top-K(DeepSeek 방식), sigmoid 기반 biased Top-K, 그리고 Triton 커널로 구현된 fused router까지 다양한 라우팅 전략을 지원한다.

관련 소스 경로:

구조도

router_logits (m, num_experts)
         │
         ├─► softmax Top-K ────────► topk_weights, topk_ids
         │
         ├─► grouped Top-K ────────► 그룹 선택 → 그룹 내 Top-K
         │   (DeepSeek V2/V3)
         │
         ├─► biased grouped Top-K ─► sigmoid + correction_bias
         │   (DeepSeek V3)
         │
         ├─► Fused Router ─────────► 라우팅 + MatMul을 하나의 커널로
         │   (Triton CUDACore / TensorCore)
         │
         └─► Triton Kernels routing ► RoutingData, GatherIndx, ScatterIndx

핵심 코드 분석

1. TopK 클래스: 다중 플랫폼 지원

TopKMultiPlatformOp을 상속하여 CUDA, CPU, NPU에서 각각 최적화된 경로를 제공한다.

class TopK(MultiPlatformOp):
    def __init__(self, top_k, *, use_grouped_topk=False,
                 topk_group=None, num_expert_group=None,
                 renormalize=True, scoring_func="softmax", ...):
        self.topk_config = TopKConfig(
            top_k=top_k, use_grouped_topk=use_grouped_topk,
            renormalize=renormalize, scoring_func=scoring_func, ...
        )

TopKConfig는 라우팅에 필요한 모든 파라미터를 담는 데이터클래스다.

2. 기본 softmax Top-K

가장 기본적인 라우팅. sgl_kerneltopk_softmax CUDA 커널을 사용한다.

def fused_topk(hidden_states, gating_output, topk, renormalize, ...):
    topk_weights = torch.empty(M, topk, dtype=torch.float32, device=device)
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=device)

    if scoring_func == "softmax":
        topk_softmax(topk_weights, topk_ids, gating_output, renormalize)
    elif scoring_func == "sigmoid":
        topk_sigmoid(topk_weights, topk_ids, gating_output,
                     renormalize, correction_bias)

topk_softmax는 softmax 계산과 Top-K 선택을 단일 커널에서 수행한다. renormalize=True이면 선택된 K개 전문가의 가중치 합이 1이 되도록 정규화한다.

3. Grouped Top-K: DeepSeek 방식

DeepSeek V2/V3에서 사용하는 그룹 기반 라우팅. 전문가를 그룹으로 나누고, 먼저 상위 그룹을 선택한 후 그 안에서 Top-K를 수행한다.

def grouped_topk_gpu(hidden_states, gating_output, topk, renormalize,
                     num_expert_group=None, topk_group=None, ...):
    scores = torch.softmax(gating_output, dim=-1)
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)

예를 들어 256개 전문가가 8개 그룹(각 32개)으로 나뉘고, topk_group=4이면 먼저 상위 4개 그룹을 선택한다. 그 후 선택된 그룹 내에서만 Top-K를 수행한다.

    score_mask = group_mask.unsqueeze(-1).expand(...).reshape(num_token, -1)
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, ...)

4. Biased Grouped Top-K: correction_bias 적용

DeepSeek V3는 sigmoid 스코어링과 correction bias를 함께 사용한다.

def biased_grouped_topk_impl(hidden_states, gating_output,
                              correction_bias, topk, renormalize, ...):
    scores = gating_output.sigmoid()
    scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
    group_scores = (
        scores_for_choice.view(num_token, num_expert_group, -1)
        .topk(2, dim=-1)[0].sum(dim=-1)
    )

correction_bias는 학습된 보정치로, 특정 전문가의 선택 확률을 조정한다. 그룹 스코어 계산 시 각 그룹의 상위 2개 값의 합을 사용하는 것이 특징이다.

5. Fused MoE Router: 라우팅 MatMul 융합

router.pyFusedMoeRouter는 라우터의 linear projection과 Top-K 선택을 하나의 Triton 커널로 융합한다.

@triton.jit
def fused_moe_router_cudacore_kernel(
    input_ptr, moe_router_weight_ptr,
    topk_weights_ptr, topk_ids_ptr, ...):
    # MatMul: input × router_weight
    logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)

    # Softcapping
    if moe_softcapping != 0:
        logits_scaled = logits / moe_softcapping
        logits_softcapped = tl.tanh(logits_scaled) * moe_softcapping

    # Top-1
    top1 = tl.argmax(logits_softcapped, axis=0)
    top1_v = tl.max(logits_softcapped, axis=0)
    invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0)

배치 크기에 따라 CUDACore(소규모)와 TensorCore(대규모) 경로로 분기한다.

def fused_moe_router_shim(...):
    if (bs >= 512 or num_experts > 8) and hidden_dim % BLOCK_SIZE_K == 0:
        return fused_moe_router_tensorcore(...)
    else:
        return fused_moe_router_cudacore(...)

6. 공유 전문가 슬롯 할당

Grouped Top-K에서 공유 전문가가 있으면, 마지막 Top-K 슬롯을 공유 전문가에 랜덤 할당한다.

if num_fused_shared_experts:
    topk_ids[:, -1] = torch.randint(
        low=num_experts,
        high=num_experts + num_fused_shared_experts,
        size=(topk_ids.size(0),), dtype=topk_ids.dtype, device=topk_ids.device,
    )
    if routed_scaling_factor is not None:
        topk_weights[:, -1] = (
            topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
        )

라우팅 전략 비교

전략 모델 예시 특징
softmax Top-K Mixtral 가장 기본, 부드러운 확률 분포
grouped Top-K DeepSeek V2 그룹 간 밸런싱, 전문가 특화 촉진
biased sigmoid DeepSeek V3 correction bias로 부하 조정
Fused Router Gemma2 등 라우터 MatMul + Top-K 융합

출력 포맷

class TopKOutputFormat(IntEnum):
    STANDARD = auto()      # topk_weights, topk_ids, router_logits
    TRITON_KERNEL = auto() # RoutingData, GatherIndx, ScatterIndx
    BYPASSED = auto()      # hidden_states + config (FlashInfer용)

BYPASSED 포맷은 FlashInfer TRT-LLM처럼 라우팅과 전문가 연산을 하나의 커널로 처리하는 경우에 사용된다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글