[axolotl] ScatterMoE LoRA 최적화: Grouped-Gram 및 Sync-free 역전파 구현
PR 링크: axolotl-ai-cloud/axolotl#3712 상태: Merged | 변경: +383 / -31
들어가며
최근 대규모 Mixture-of-Experts (MoE) 모델이 주류가 되면서, E(Expert) 수가 128개 이상인 모델에서의 LoRA 학습 효율성이 중요한 과제로 떠올랐습니다. 기존의 scattermoe_lora 구현체는 전문가별로 루프를 돌며 expert_offsets[e].item()을 호출하는 과정에서 빈번한 호스트-디바이스 간 동기화(sync)가 발생하여, 전문가 수가 늘어날수록 성능이 급격히 저하되는 문제가 있었습니다. 본 PR은 이러한 병목을 해결하기 위해 Grouped-Gram 연산을 도입하고, 중간 결과물(XA, YB)을 재사용하여 불필요한 재계산을 제거하는 최적화를 수행했습니다.
코드 분석
1. grouped_gram.py: Grouped-Gram 커널 도입
기존 방식은 출력 차원 블록마다 XA(X@A^T)나 YB(dY@B)를 매번 재계산했습니다. 새로운 _grouped_gram_kernel은 이를 한 번만 계산하여 메모리에 유지함으로써, 고차원 전문가 환경에서 연산 효율을 극대화합니다.
# Before: Per-output-block recompute (implicit in old split kernel)
# After: Precomputed XA/YB used in grouped Gram product
acc += tl.dot(tl.trans(p), q, allow_tf32=allow_tf32)
2. parallel_linear_lora.py: 역전파 경로 최적화
backward 함수 내에서 yb와 xa를 미리 계산하여 grouped_lora_weight_grads에 전달합니다. 특히 dX_lora 계산 시 기존의 전문가별 루프를 제거하고, scatter2scatter를 활용한 동기화 없는(sync-free) 경로를 구현했습니다.
# Before: Per-expert loop with O(E) host syncs
# After: Single launch grouped GEMMs reusing precomputed YB
d_lora_A, d_lora_B = grouped_lora_weight_grads(
grouped_grad_out, grouped_x, yb, xa, lora_A, lora_B, expert_offsets, E, scaling
)
왜 이게 좋은가
이번 최적화의 핵심은 '불필요한 동기화 제거'와 '중간값 재사용'입니다.
- 성능 향상: Qwen3-MoE 및 DeepSeek와 같은 대규모 모델 환경에서 비융합(non-fused) 경로 기준 최대 2.2배의 성능 향상을 보였습니다.
dA/dB커널 단독으로는 2~17배의 속도 개선이 확인되었습니다. - 확장성: 전문가 수(E)가 128개 이상으로 증가해도, 호스트 동기화가 O(E)에서 O(1)로 줄어들어 선형적인 성능 저하를 방지합니다.
- 교훈: GPU 연산에서 메모리 대역폭보다 호스트-디바이스 간의 동기화(Sync)가 더 큰 병목이 될 수 있음을 보여줍니다. 중간값(Intermediate)을 저장하는 메모리 오버헤드(+5~30MB)가 연산 속도 향상이라는 이득보다 훨씬 작다면, 적극적으로 캐싱하는 전략이 유효합니다.
결론
이번 PR은 대규모 MoE LoRA 학습의 고질적인 병목을 성공적으로 해결했습니다. 특히 Triton을 활용한 커널 최적화와 파이썬 레벨의 동기화 제거는 대규모 모델 학습 파이프라인 설계에 중요한 레퍼런스가 될 것입니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang: Triton 버전 업그레이드에 따른 MoE 성능 회귀 해결 및 설정 자동화
- [Axolotl] ScatterMoE LoRA 최적화: 벤치마크, 커널 분할, autograd 통합
- [논문리뷰] On the Scaling of PEFT: Towards Million Personal Models of Trillion Parameters
- [vllm] vLLM의 MoE Permute 최적화: 버퍼 사전 할당을 통한 성능 향상
- [LlamaFactory] LlamaFactory의 Triton 기반 Fused MoE 커널 도입: 40% 이상의 성능 향상
PR Analysis 의 다른글
- 이전글 [cpython] Python re 모듈의 findall, sub, subn 성능 개선: PyList_AppendTakeRef 도입
- 현재글 : [axolotl] ScatterMoE LoRA 최적화: Grouped-Gram 및 Sync-free 역전파 구현
- 다음글 없음
댓글