[Axolotl] ScatterMoE LoRA 최적화: 벤치마크, 커널 분할, autograd 통합
PR 링크: axolotl-ai-cloud/axolotl#3513 상태: Merged | 변경: +1898 / -129
들어가며
ScatterMoE(Sparse Mixture of Experts)에 LoRA를 적용할 때, expert 수가 많은 모델(Qwen3.5: 256 experts)에서는 단일 fused 커널이 비효율적입니다. 이 PR은 모델 크기에 따라 fused와 split 전략을 자동 선택하는 로직, 종합 벤치마크 도구, 그리고 autograd 경로 최적화를 추가합니다.
핵심 코드 분석
1. 벤치마크 도구 추가
# benchmarks/bench_scattermoe_lora.py
BUILTIN_CONFIGS = {
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
"Qwen3-30B-A3B": (128, 2048, 768, 8),
"OLMoE-1B-7B": (64, 2048, 1024, 8),
"Mixtral-8x7B": (8, 4096, 14336, 2),
}
각 MoE 모델의 expert 수(E), hidden size(H), intermediate size(I), top-k를 미리 정의하고, forward/backward dX/backward dA-dB 각 커널의 성능을 개별 측정합니다. HuggingFace 모델 ID도 직접 입력할 수 있어 새 모델의 성능을 빠르게 프로파일링할 수 있습니다.
2. Fused vs Split 자동 선택
# Forward에서 expert가 적고 weight가 큰 경우 split 전략이 유리
dispatch = (
"split"
if (
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
)
else "fused"
)
Expert가 적고 weight 행렬이 큰 모델(Mixtral-8x7B)에서는 base scatter2scatter와 LoRA를 별도로 실행하는 split 전략이, expert가 많고 weight가 작은 모델(Qwen3.5)에서는 fused 전략이 유리합니다. 임계값 기반으로 자동 선택합니다.
3. Full autograd 벤치마크
def _run_autograd():
out = ScatterMoELoRA.apply(
x, W, k, sei, ssi, eo, lA, lB, 2.0,
None, None, False, False, True, False,
)
out.sum().backward()
개별 커널 성능뿐 아니라 forward + backward의 전체 autograd 경로 성능과 peak memory delta를 측정합니다.
왜 이게 좋은가
- 데이터 기반 의사결정: 벤치마크로 실측한 결과를 바탕으로 fused/split 임계값을 설정합니다. 추측이 아닌 측정입니다.
- 모델별 최적화: 256-expert 모델과 8-expert 모델은 완전히 다른 최적 전략을 가집니다. 자동 선택으로 사용자가 신경 쓸 필요가 없습니다.
- 재현 가능한 벤치마크:
bench_scattermoe_lora.py로 누구나 자신의 GPU에서 성능을 측정하고, PR의 성능 주장을 검증할 수 있습니다.
정리
ScatterMoE LoRA의 성능을 체계적으로 벤치마크하고, 모델 특성에 따른 자동 최적화 전략을 구현한 PR입니다. MoE + LoRA 조합의 실전 성능 최적화에 대한 좋은 참고 사례입니다.
참고 자료
이 포스트는 AI가 작성하였으며, 사실과 다를 수 있습니다. 정확한 정보는 원본 PR을 참고해 주세요.
관련 포스트
- [Axolotl] ScatterMoE LoRA Triton 커널의 autotune 탐색 공간 축소
- [Axolotl] LoRA 커널에 bias, dropout, DoRA, embedding 지원 추가
- [axolotl] ScatterMoE 커널 라우팅 통합: Softmax/Sigmoid 기반 라우팅과 Autotune Telemetry 추가
- [Axolotl] Qwen 3.5 모델 Liger 커널 지원 및 fused RMSNorm+Gated 커널 추가
- [axolotl] Axolotl: Triton 커널을 활용한 Entropy 및 Selective Log Softmax 최적화
PR Analysis 의 다른글
- 이전글 [triton] Custom DSL Plugin Ops 지원
- 현재글 : [Axolotl] ScatterMoE LoRA 최적화: 벤치마크, 커널 분할, autograd 통합
- 다음글 [axolotl] Gemma 3 QLoRA 설정 개선: Vision Tower 동결과 model_type 제거
댓글