본문으로 건너뛰기

[Axolotl] MXFP4 양자화 지원 추가

PR 링크: axolotl-ai-cloud/axolotl#3375 상태: Merged | 변경: +181 / -2

들어가며

LLM 파인튜닝에서 메모리 절약은 핵심 과제다. Axolotl은 이미 INT4, FP8, NVFP4 등의 Quantization-Aware Training(QAT)을 지원하고 있었다. 이 PR은 여기에 MXFP4(Microscaling FP4) 양자화를 추가한다. MXFP4는 NVIDIA Blackwell 아키텍처에서 지원하는 4비트 부동소수점 포맷으로, block_size=32 단위로 스케일링하여 정밀도 손실을 최소화한다.

핵심 코드 분석

1. Enum에 mxfp4 타입 등록

Before (enums.py):

class TorchAOQuantDType(Enum):
    int8 = torch.int8
    float8_e4m3fn = torch.float8_e4m3fn
    nvfp4 = "nvfp4"

After:

class TorchAOQuantDType(Enum):
    int8 = torch.int8
    float8_e4m3fn = torch.float8_e4m3fn
    nvfp4 = "nvfp4"
    mxfp4 = "mxfp4"

2. Quantization Config 생성 로직

기존 NVFP4는 group_size=16을 사용하지만, MXFP4는 block_size=32가 필수다. get_quantization_config()에 MXFP4 분기를 추가했다.

After (quantization.py):

if weight_dtype == TorchAOQuantDType.mxfp4:
    from torchao.prototype.qat import MXFakeQuantizeConfig

    block_size = group_size if group_size is not None else 32
    if block_size != 32:
        raise ValueError(
            "MXFP4 quantization must use a block_size (group_size) of 32"
        )

    return MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=block_size)

torch.float4_e2m1fn_x2는 2비트 지수 + 1비트 가수의 FP4 포맷 2개를 묶은 타입이다.

3. QATConfig 분기 처리

MXFP4는 기존 QATConfig(base_config) 방식과 다르게 activation과 weight config를 분리하여 전달해야 한다.

Before:

qat_config = QATConfig(base_config)

After:

if isinstance(base_config, MXFakeQuantizeConfig):
    qat_config = QATConfig(
        activation_config=base_config,
        weight_config=base_config,
    )
else:
    qat_config = QATConfig(base_config)

이 분기가 필요한 이유는 MXFakeQuantizeConfig가 torchao의 prototype API에 속해 있어, 기존 QATConfig의 단일 인자 생성자와 호환되지 않기 때문이다. Embedding 레이어의 경우 activation 양자화가 불필요하므로 weight_config만 전달한다.

4. 예제 설정 파일

PR에는 Llama-3.2-3B 모델에 MXFP4 QAT를 적용하는 예제 YAML도 포함되어 있다:

qat:
  activation_dtype: mxfp4
  weight_dtype: mxfp4
  group_size: 32

왜 이게 좋은가

  • 4비트 정밀도: FP4 포맷으로 메모리 사용량을 INT8 대비 절반으로 줄이면서 QAT로 정확도를 보존한다
  • Blackwell 최적화: NVIDIA B200 등 최신 GPU에서 하드웨어 가속을 활용할 수 있다
  • 테스트 포함: test_qat.pytest_quantization.py에 schema 검증, config 생성, 모델 QAT 적용 테스트를 추가하여 안정성을 확보했다

정리

새로운 양자화 포맷을 프레임워크에 통합하는 전형적인 패턴을 보여주는 PR이다. (1) Enum 등록, (2) Config 팩토리 분기, (3) QAT 적용 분기, (4) 예제 + 테스트. 특히 torchao의 prototype API를 사용하면서 기존 인터페이스와의 호환성을 isinstance 분기로 해결한 점이 실용적이다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글