[axolotl] Axolotl에 도입된 Stateless 최적화: SinkGD로 메모리 효율 극대화하기
PR 링크: axolotl-ai-cloud/axolotl#3763 상태: Merged | 변경: +614 / -2
들어가며
LLM 학습에서 가장 큰 병목 중 하나는 옵티마이저 상태(Optimizer State)가 차지하는 거대한 메모리 점유율입니다. 특히 AdamW를 사용할 경우 파라미터의 2배 이상을 상태값으로 저장해야 하므로, 대규모 모델 파인튜닝 시 메모리 부족 문제가 빈번하게 발생합니다. 최근 Axolotl에 추가된 SinkGD는 'Gradient Multi-Normalization' 기법을 통해 2D 가중치 행렬에 대해 상태를 저장하지 않는(stateless) 방식을 채택하여 이러한 문제를 획기적으로 해결합니다.
코드 분석
1. Stateless SinkGD 구현 (src/axolotl/utils/optimizers/sinkgd.py)
핵심은 sr_sinkhorn 함수입니다. 모멘텀이나 분산(variance)을 저장하는 대신, 그래디언트를 행과 열 방향으로 반복적으로 L2 정규화하여 업데이트 방향을 결정합니다.
def sr_sinkhorn(grad: Tensor, iters: int, eps: float) -> Tensor:
x = grad
for _ in range(iters):
x = x * (sqrt_n / x.norm(dim=-1, keepdim=True).clamp_min(eps))
x = x * (sqrt_m / x.norm(dim=-2, keepdim=True).clamp_min(eps))
return x
이 과정은 torch.compile을 통해 커널로 융합(fused)되어 실행되므로, 메모리 트래픽을 최소화하고 연산 효율을 극대화합니다.
2. 하이브리드 옵티마이저 전략
모든 파라미터에 SinkGD를 적용할 수는 없습니다. 임베딩이나 LM Head와 같은 1D 파라미터는 여전히 AdamW가 필요합니다. SinkGD는 이를 위해 torchao의 8-bit AdamW를 폴백(fallback)으로 사용하여 메모리 효율을 유지합니다.
# SinkGD 클래스 내부의 Adam 폴백 로직
def _adam_fallback(self, p: Tensor, grad: Tensor, group: dict, lr: Tensor) -> None:
# ... (생략) ...
self._compiled_adam(
p_local.detach(), grad_local, state["step"],
state["exp_avg"], state["exp_avg_sq"], ...
)
왜 이게 좋은가
- 메모리 절감: 8B 모델 풀 파인튜닝 기준, 2D 가중치에 대해 옵티마이저 상태를 0으로 만듦으로써 기존 8-bit AdamW 대비 약 87%의 옵티마이저 메모리를 절감합니다. 전체 피크 메모리 기준으로는 약 24% 감소 효과가 있습니다.
- 성능 유지:
torch.compile을 활용한 커널 융합 덕분에 end-to-end 처리량(throughput)이 기존 AdamW와 대등하거나 더 뛰어납니다. - 유연성:
sinkhorn_iters와sinkgd_lr_scale을 설정값으로 노출하여 모델의 특성에 맞게 정규화 강도를 조절할 수 있습니다.
교훈
대규모 모델 학습 시 모든 파라미터에 동일한 옵티마이저를 적용하는 것은 비효율적일 수 있습니다. 가중치의 형태(2D vs 1D)에 따라 최적의 업데이트 방식을 분리하는 하이브리드 접근법은 메모리 제약이 심한 환경에서 매우 강력한 전략이 됩니다.
리뷰어 피드백 분석
이번 PR에서는 BaseOptimizerFactory 대신 transformers의 is_optimizer_factory를 활용하도록 리팩토링되었습니다. 이는 Axolotl의 특정 구현에 의존하지 않고 Hugging Face 생태계와의 호환성을 높이는 올바른 방향입니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [axolotl] Axolotl, 대규모 언어 모델 학습 시 메모리 부족 문제 해결: 효율적인 데이터셋 처리 개선
- [axolotl] Axolotl, Marlin W4A16 도입으로 MoE 모델 추론 속도 1.79배 향상 및 품질 개선
- [sglang] SGLang 성능 최적화: torch.cuda.empty_cache() 호출 제어를 통한 가중치 업데이트 병목 해결
- [ACE-Step-1.5] 외부 의존성을 걷어내고 성능을 잡다: ACE-Step 1.5의 커스텀 vLLM 엔진 도입기
- [onnxruntime] ONNX Runtime QMoE SwiGLU GEMV 최적화: Split-K2 커널로 LLM 추론 가속화
PR Analysis 의 다른글
- 이전글 [sglang] [NPU] GLM-4.7-Flash 성능 최적화: Fused Triton 커널로 연산 병목 해결하기
- 현재글 : [axolotl] Axolotl에 도입된 Stateless 최적화: SinkGD로 메모리 효율 극대화하기
- 다음글 없음
댓글