본문으로 건너뛰기

[axolotl] Flash Optimizer 지원 추가: FlashAdamW, FlashSGD, FlashLion 등 5종 커스텀 옵티마이저

PR 링크: axolotl-ai-cloud/axolotl#3457 상태: Merged | 변경: +115 / -3

들어가며

Flash Optimizer는 GPU 메모리 효율성과 학습 속도를 개선하기 위해 설계된 최적화된 옵티마이저 라이브러리입니다. 이 PR은 flashoptim 패키지의 5가지 옵티마이저(FlashAdamW, FlashAdam, FlashSGD, FlashSGDW, FlashLion)를 axolotl에 통합하고, DeepSpeed 비호환성과 FSDP 버전 제약에 대한 검증 로직을 추가합니다.

핵심 코드 분석

1. 옵티마이저 등록 및 구현

# src/axolotl/core/builders/base.py
elif self.cfg.optimizer == "flash_adamw":
    from flashoptim import FlashAdamW
    optimizer_cls = FlashAdamW
    optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_lion":
    from flashoptim import FlashLion
    optimizer_cls = FlashLion
    if "betas" in adam_kwargs:
        optimizer_kwargs["betas"] = adam_kwargs["betas"]

각 옵티마이저에 맞는 하이퍼파라미터만 전달합니다. FlashLion은 betas만 받고, FlashSGD/SGDw는 adam_kwargs를 전달하지 않습니다.

2. 호환성 검증

@model_validator(mode="before")
@classmethod
def check_flashoptim_deepspeed_fsdp(cls, data):
    optimizer = data.get("optimizer") or ""
    if str(optimizer).startswith("flash_"):
        if data.get("deepspeed"):
            raise ValueError(
                f"{optimizer} optimizer is incompatible with DeepSpeed."
            )
        if data.get("fsdp") or data.get("fsdp_config"):
            fsdp_version = cls._resolve_fsdp_version(data)
            if str(fsdp_version) != "2":
                raise ValueError(
                    f"{optimizer} optimizer is only compatible with FSDP2."
                )

Flash Optimizer는 DDP와 FSDP2만 지원하므로, DeepSpeed나 FSDP1 사용 시 설정 검증 단계에서 즉시 에러를 발생시킵니다. _resolve_fsdp_version 헬퍼를 추출하여 Muon 옵티마이저 검증과 코드를 공유합니다.

3. End-to-End 파라미터화 테스트

@pytest.mark.parametrize(
    "optimizer_name,expected_class,learning_rate",
    [
        ("flash_adamw", "FlashAdamW", 0.00001),
        ("flash_adam", "FlashAdam", 0.00001),
        ("flash_sgd", "FlashSGD", 0.01),
        ("flash_lion", "FlashLion", 0.0001),
    ],
)
def test_flash_optimizers(tmp_path, optimizer_name, expected_class, learning_rate):
    ...
    _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
    assert trainer.optimizer.optimizer.__class__.__name__ == expected_class

각 옵티마이저별로 적절한 learning rate를 설정하고, 실제 학습이 완료된 후 올바른 옵티마이저 클래스가 사용되었는지 검증합니다.

왜 이게 좋은가

새로운 옵티마이저 통합의 핵심은 호환성 경계의 명확한 정의입니다. Flash Optimizer가 FSDP2만 지원한다는 제약을 설정 검증 단계에서 catch하여, 런타임에 불투명한 에러가 발생하는 것을 방지합니다. 또한 _resolve_fsdp_version 헬퍼를 추출한 리팩토링은 Muon 옵티마이저와의 검증 로직 중복을 제거하면서도, 각 옵티마이저의 에러 메시지는 고유하게 유지합니다.

정리

옵티마이저 용도 특이사항
FlashAdamW 일반 학습 adam_kwargs 전체 전달
FlashAdam 일반 학습 adam_kwargs 전체 전달
FlashSGD 빠른 수렴 LR 0.01, adam_kwargs 미전달
FlashSGDW Weight Decay SGD adam_kwargs 미전달
FlashLion 메모리 효율 betas만 전달

참고 자료

알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글