본문으로 건너뛰기

[axolotl] ScatterMoE 커널 라우팅 통합: Softmax/Sigmoid 기반 라우팅과 Autotune Telemetry 추가

PR 링크: axolotl-ai-cloud/axolotl#3475 상태: Merged | 변경: +1988 / -35

들어가며

Mixture-of-Experts(MoE) 모델에서 라우팅(routing)은 각 토큰을 어떤 expert에 할당할지 결정하는 핵심 로직입니다. 모델 아키텍처마다 라우팅 전략이 다른데(Qwen은 softmax+topk, DeepSeek V3는 sigmoid+topk), 기존에는 이 로직이 분산되어 있었습니다. 이 PR은 라우팅 전략을 통합 함수로 정리하고, Triton 커널의 autotune 결과를 자동으로 수집하는 telemetry 시스템을 추가합니다.

핵심 코드 분석

1. 라우팅 전략 통합 함수

두 가지 주요 라우팅 패턴을 독립 함수로 분리했습니다:

def _softmax_topk_route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
    """Softmax->topk routing (Qwen, OLMoE, Mixtral, MiniMax)."""
    router_logits = F.linear(hidden_states, gate_weight)
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(hidden_states, gate_lora_delta)
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    if getattr(base_gate, "norm_topk_prob", True):
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    return routing_weights, selected_experts, top_k, num_experts

def _sigmoid_topk_route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
    """Sigmoid->topk routing (GLM, DeepSeek V3, MiniMax M2)."""
    router_logits = F.linear(hidden_states.float(), gate_weight.float())
    router_probs = router_logits.sigmoid()
    # e_score_correction_bias, group-based selection, routed_scaling_factor 지원
    ...

_sigmoid_topk_route는 DeepSeek V3의 e_score_correction_bias, 그룹 기반 expert 선택, routed_scaling_factor 등 복잡한 라우팅 로직도 지원합니다.

2. Autotune Telemetry Callback

Triton의 @triton.autotune이 선택한 커널 설정을 자동으로 수집하는 callback:

class AutotuneReportCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if self._reported:
            return
        configs = collect_autotune_configs()
        if not configs:
            if state.global_step >= _MAX_POLL_STEP:
                self._reported = True
            return
        self._reported = True
        # telemetry로 전송

학습 시작 후 최대 5 step까지만 autotune 데이터를 확인하고, 이후에는 _reported 플래그로 hot-path 비용을 0으로 만듭니다.

왜 이게 좋은가

라우팅 로직의 통합은 코드 중복 제거일관성 확보 두 가지 이점을 제공합니다. 새로운 MoE 모델 지원 시 기존 라우팅 함수를 재사용할 수 있으며, 버그 수정이 모든 모델에 일괄 적용됩니다. Autotune telemetry는 사용자 환경에서 실제로 선택된 커널 설정을 수집하여, 향후 기본값 최적화에 활용할 수 있는 데이터를 제공합니다. 특히 _MAX_POLL_STEP_reported 플래그를 통해 학습 성능에 영향을 주지 않도록 설계한 점이 인상적입니다.

정리

항목 내용
라우팅 softmax_topk, sigmoid_topk 2개 통합 함수
Telemetry AutotuneReportCallback + autotune_collector
지원 모델 Qwen, OLMoE, Mixtral, DeepSeek V3, GLM, MiniMax 등

참고 자료

알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글