[axolotl] ScatterMoE 커널 라우팅 통합: Softmax/Sigmoid 기반 라우팅과 Autotune Telemetry 추가
PR 링크: axolotl-ai-cloud/axolotl#3475 상태: Merged | 변경: +1988 / -35
들어가며
Mixture-of-Experts(MoE) 모델에서 라우팅(routing)은 각 토큰을 어떤 expert에 할당할지 결정하는 핵심 로직입니다. 모델 아키텍처마다 라우팅 전략이 다른데(Qwen은 softmax+topk, DeepSeek V3는 sigmoid+topk), 기존에는 이 로직이 분산되어 있었습니다. 이 PR은 라우팅 전략을 통합 함수로 정리하고, Triton 커널의 autotune 결과를 자동으로 수집하는 telemetry 시스템을 추가합니다.
핵심 코드 분석
1. 라우팅 전략 통합 함수
두 가지 주요 라우팅 패턴을 독립 함수로 분리했습니다:
def _softmax_topk_route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
"""Softmax->topk routing (Qwen, OLMoE, Mixtral, MiniMax)."""
router_logits = F.linear(hidden_states, gate_weight)
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(hidden_states, gate_lora_delta)
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if getattr(base_gate, "norm_topk_prob", True):
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
return routing_weights, selected_experts, top_k, num_experts
def _sigmoid_topk_route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
"""Sigmoid->topk routing (GLM, DeepSeek V3, MiniMax M2)."""
router_logits = F.linear(hidden_states.float(), gate_weight.float())
router_probs = router_logits.sigmoid()
# e_score_correction_bias, group-based selection, routed_scaling_factor 지원
...
_sigmoid_topk_route는 DeepSeek V3의 e_score_correction_bias, 그룹 기반 expert 선택, routed_scaling_factor 등 복잡한 라우팅 로직도 지원합니다.
2. Autotune Telemetry Callback
Triton의 @triton.autotune이 선택한 커널 설정을 자동으로 수집하는 callback:
class AutotuneReportCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs):
if self._reported:
return
configs = collect_autotune_configs()
if not configs:
if state.global_step >= _MAX_POLL_STEP:
self._reported = True
return
self._reported = True
# telemetry로 전송
학습 시작 후 최대 5 step까지만 autotune 데이터를 확인하고, 이후에는 _reported 플래그로 hot-path 비용을 0으로 만듭니다.
왜 이게 좋은가
라우팅 로직의 통합은 코드 중복 제거와 일관성 확보 두 가지 이점을 제공합니다. 새로운 MoE 모델 지원 시 기존 라우팅 함수를 재사용할 수 있으며, 버그 수정이 모든 모델에 일괄 적용됩니다. Autotune telemetry는 사용자 환경에서 실제로 선택된 커널 설정을 수집하여, 향후 기본값 최적화에 활용할 수 있는 데이터를 제공합니다. 특히 _MAX_POLL_STEP과 _reported 플래그를 통해 학습 성능에 영향을 주지 않도록 설계한 점이 인상적입니다.
정리
| 항목 | 내용 |
|---|---|
| 라우팅 | softmax_topk, sigmoid_topk 2개 통합 함수 |
| Telemetry | AutotuneReportCallback + autotune_collector |
| 지원 모델 | Qwen, OLMoE, Mixtral, DeepSeek V3, GLM, MiniMax 등 |
참고 자료
알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [axolotl] 코드 품질 개선: CONTRIBUTING.md 플레이스홀더 수정, bare except 제거, convert.py 테스트 추가
- 현재글 : [axolotl] ScatterMoE 커널 라우팅 통합: Softmax/Sigmoid 기반 라우팅과 Autotune Telemetry 추가
- 다음글 [Open WebUI] Artifacts 컴포넌트 메모리 누수 수정
댓글