본문으로 건너뛰기

[axolotl] Axolotl MoE 모델 최적화: Tiled-MLP 도입 및 FSDP2 통합으로 성능 극대화

PR 링크: axolotl-ai-cloud/axolotl#3666 상태: Merged | 변경: +4352 / -147

들어가며

최근 대규모 언어 모델(LLM) 분야에서는 Mixture-of-Experts (MoE) 아키텍처가 뛰어난 성능과 효율성으로 주목받고 있습니다. MoE는 모델의 일부 전문가(expert) 네트워크만을 활성화하여 연산량을 줄이면서도 모델의 표현력을 높일 수 있다는 장점이 있습니다. 하지만 MoE 모델을 효율적으로 학습하고 서빙하는 것은 복잡한 엔지니어링 과제를 수반합니다. 특히, 분산 학습 환경에서 MoE 모델의 성능을 최적화하는 것은 매우 중요합니다.

Axolotl은 LLM 개발을 위한 프레임워크로, 지속적으로 모델 학습 및 추론 성능을 개선하기 위한 노력을 기울이고 있습니다. 이번 PR(#3663 기반)은 Axolotl에서 MoE 모델의 성능을 획기적으로 개선하기 위한 여러 최적화를 포함하고 있습니다. 주요 내용은 다음과 같습니다:

  1. Tiled-MLP for MoE: MoE 블록에 Tiled-MLP를 적용하고, FSDP2(Fully Sharded Data Parallelism 2) 환경에서의 backward pass 중 발생하는 resharding 문제를 해결했습니다.
  2. Gradient Accumulator dtype fix: Gradient Accumulator의 데이터 타입을 FP32에서 파라미터의 dtype으로 변경하여 메모리 사용량을 줄이고 처리량을 높였습니다.
  3. Default shard count heuristic fix: MoE 모델에서 토큰 샤딩 방식을 개선하여 긴 컨텍스트에서의 속도를 크게 향상시켰습니다.

이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 각 최적화가 왜 효과적이며 어떤 성능 향상을 가져왔는지 실제 코드를 인용하여 설명하겠습니다.

코드 분석

1. MoE 블록 패치 및 FSDP2 Reshard Fix

이 변경의 핵심은 MoE 레이어에 Tiled-MLP를 적용하고, FSDP2 환경에서 backward pass 중에 발생하는 문제를 해결하는 것입니다. MoE 모델은 여러 전문가 네트워크를 가지고 있으며, 각 토큰은 특정 전문가에게 라우팅됩니다. Tiled-MLP는 MLP 레이어를 더 작은 타일(tile)로 나누어 연산을 수행함으로써 메모리 대역폭 병목 현상을 완화하고 GPU 활용률을 높이는 기법입니다.

변경 전 (개념적):

FSDP2 환경에서 backward pass 중에 가중치(weight)가 각 랭크(rank)로 재분배(reshard)되는 과정이 Tiled-MLP의 backward 연산과 충돌하여 정확성 문제를 야기할 수 있었습니다. 특히, 토큰 축(token-axis)으로 샤딩된 경우 그래디언트 평균 계산이 잘못될 수 있었습니다.

변경 후 (fix(tiled-mlp): defer FSDP2 reshard + correct per-shard grad accumulation):

Tiled-MLP의 backward 연산 루프 내에서 FSDP2의 set_reshard_after_backward(False)를 설정하여 backward pass 완료 후 resharding이 일어나도록 지연시켰습니다. 이는 Tiled-MLP의 backward 연산이 완료될 때까지 모든 랭크가 동일한 상태를 유지하도록 보장합니다.

# 개념적 코드 예시 (실제 diff는 더 복잡함)
# ... Tiled-MLP backward pass 시작 ...
# FSDP2 resharding 지연 설정
fsdp_layer.set_reshard_after_backward(False)

# ... Tiled-MLP backward 연산 수행 ...
# (이 과정에서 per-shard grad accumulation 정확성 보장)

# ... Tiled-MLP backward pass 종료 ...
# FSDP2 resharding 재개 (또는 기본 동작 복구)
fsdp_layer.set_reshard_after_backward(True) # 또는 기본값

이 변경은 MoE 모델에서 FSDP2를 사용할 때 그래디언트의 정확성을 보장하며, Tiled-MLP의 이점을 FSDP2 환경에서도 온전히 활용할 수 있게 합니다.

2. Gradient Accumulator FP32 → Param Dtype Fix

MoE 모델은 종종 매우 큰 중간 표현(intermediate representation) 크기를 가지며, 이는 메모리 사용량 증가로 이어집니다. 그래디언트 누적(gradient accumulation) 과정에서 FP32를 사용하는 것은 수치적 안정성을 높일 수 있지만, 불필요한 메모리 사용과 연산 오버헤드를 발생시킬 수 있습니다.

변경 전:

Gradient Accumulator가 기본적으로 FP32 타입을 사용하여 그래디언트를 누적했습니다. 이는 특히 메모리가 제한적인 환경에서 큰 병목이 될 수 있었습니다.

변경 후 (fix(tiled-mlp): default grad accumulator to param dtype, skip redundant casts):

Gradient Accumulator의 기본 데이터 타입을 모델 파라미터의 dtype (예: bf16)으로 변경했습니다. FP32 누적이 필요한 경우 AXOLOTL_TILED_MLP_ACCUM_FP32=1 환경 변수를 통해 명시적으로 활성화할 수 있도록 하여, 기본적으로는 메모리 효율성을 극대화했습니다.

--- a/src/axolotl/utils/misc.py
+++ b/src/axolotl/utils/misc.py
@@ -123,7 +123,7 @@
             # We want to accumulate in fp32 for numerical stability, but we also
             # want to avoid redundant casts. If the parameter is already fp32, we
             # can just use it directly.
-            if param.dtype != torch.float32:
+            if param.dtype != torch.float32 and not cfg.fp32_accumulator:
                 accum = torch.zeros_like(param, dtype=torch.float32)
             else:
                 accum = param

이 변경은 PR 설명에 따르면 23.4 GiB의 피크 메모리 사용량을 절감하고, 처리량을 30% 향상시키는 놀라운 결과를 가져왔습니다 (intermediate=8192, seq=32K 기준).

3. Default Shard Count Heuristic Fix

MoE 모델에서 각 전문가에게 할당되는 토큰의 수는 모델의 효율성과 성능에 큰 영향을 미칩니다. 특히 긴 컨텍스트(long context)를 처리할 때, 샤딩(sharding) 방식이 최적화되지 않으면 GPU 커널의 효율성이 떨어지고 속도가 느려질 수 있습니다.

변경 전:

기존 샤드 카운트 휴리스틱은 ceil(seq / hidden)을 사용했습니다. 이는 시퀀스 길이(seq)를 히든 사이즈(hidden)로 나누어 각 샤드(GPU)가 처리해야 할 토큰 수를 결정하는 방식입니다. 하지만 긴 컨텍스트에서는 이 방식이 각 샤드당 토큰 수를 MoE Triton 커널의 최적 블록 크기(BLOCK_M=128)보다 훨씬 작게 만들어 효율성을 저하시켰습니다.

변경 후 (fix(tiled-mlp): default to ~32K tokens/shard, not ceil(seq/hidden)):

새로운 기본 휴리스틱은 ceil(seq / 32768)을 사용합니다. 이는 각 샤드당 최소 약 32K 토큰을 유지하도록 하여, MoE Triton 커널의 BLOCK_M=128 스위트 스팟에 더 가깝게 만듭니다. 이를 통해 커널 효율성을 회복하고 긴 컨텍스트에서의 속도를 크게 향상시킵니다.

--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -379,7 +379,7 @@
             # heuristic: ceil(seq / hidden)
             # This is usually too small for long context, leading to poor kernel efficiency.
             # Instead, we want to keep per-shard tokens >= ~32K.
-            num_shards = math.ceil(seq_len / hidden_size)
+            num_shards = math.ceil(seq_len / 32768)
 
             # If the number of shards is too large, we might run into issues with
             # the number of experts available. We cap it at the number of experts.

PR 설명에 따르면 이 변경은 64K 시퀀스 길이에서 3.2배의 속도 향상을 가져왔으며, 512K 시퀀스까지도 상당한 성능 개선을 보였습니다. 이는 긴 컨텍스트 모델 학습 및 추론에 매우 중요한 최적화입니다.

4. Shared Dequantization Infrastructure (MXFP4)

이 PR은 MXFP4 (Mixed Precision FP4) 형식의 가중치를 효율적으로 디퀀타이즈(dequantize)하기 위한 공유 인프라를 도입했습니다. MXFP4는 모델 가중치를 더 낮은 정밀도(4비트)로 압축하여 메모리 사용량을 줄이는 기술입니다. 이를 효율적으로 디퀀타이즈하는 것은 성능에 직결됩니다.

변경 후 (feat(scattermoe-lora): shared dequant buffer across tile shards):

selective_expert_weights와 같은 디퀀타이즈된 가중치 버퍼를 타일(tile) 루프 외부로 빼내어 재사용함으로써, 불필요한 메모리 할당 및 연산을 줄였습니다. 이는 특히 MoE 모델에서 여러 전문가를 처리할 때 효율성을 높입니다.

# 개념적 코드 예시
# ...
# 디퀀타이즈된 가중치 버퍼를 타일 루프 전에 한 번만 할당
if shared_dequant_buffer_enabled:
    dequantized_weights = allocate_shared_buffer(expert_id, tile_shape)
else:
    dequantized_weights = allocate_buffer_per_tile(tile_shape)

for tile in tiles:
    # ...
    # 할당된 버퍼를 재사용하여 연산 수행
    result = compute_with(dequantized_weights, ...)
    # ...

이 최적화는 특히 LoRA와 같은 파라미터 효율적인 기법과 함께 사용될 때, MXFP4 가중치의 디퀀타이즈 오버헤드를 줄여 전반적인 처리량을 향상시키는 데 기여합니다.

왜 이게 좋은가?

이번 PR에서 이루어진 최적화들은 다음과 같은 이유로 매우 훌륭합니다.

  1. 메모리 효율성 극대화: Gradient Accumulator의 dtype을 파라미터 dtype으로 변경함으로써 23.4 GiB의 피크 메모리 사용량을 절감했습니다. 이는 더 큰 배치 사이즈를 사용하거나, 더 큰 모델을 메모리가 제한된 환경에서 학습할 수 있게 해줍니다.
  2. 성능 향상:
    • Gradient Accumulator 최적화는 30%의 처리량 향상을 가져왔습니다.
    • 새로운 샤드 카운트 휴리스틱은 긴 컨텍스트에서 최대 3.2배의 속도 향상을 달성했습니다. 이는 긴 시퀀스를 처리하는 모델(예: 문서 요약, 코드 생성)의 학습 및 추론 시간을 크게 단축시킵니다.
    • Tiled-MLP와 FSDP2 통합은 분산 학습 환경에서의 MoE 모델 적용 가능성을 높이고, 복잡한 설정 없이도 성능을 개선할 수 있게 합니다.
  3. 일반화 가능성:
    • Tiled-MLP는 MLP 레이어의 일반적인 병목 현상을 해결하는 기법으로, MoE뿐만 아니라 다른 트랜스포머 기반 모델에도 적용될 잠재력이 있습니다.
    • MXFP4 디퀀타이즈 인프라는 모델 압축 기술을 더욱 효율적으로 활용할 수 있게 하여, 더 작은 모델 크기로 비슷한 성능을 달성하는 데 기여합니다.
  4. 견고한 엔지니어링:
    • FSDP2와의 통합은 복잡한 분산 학습 환경에서의 안정성을 보장합니다.
    • 테스트 케이스 추가(test(tiled-mlp): single-gpu MoE + scattermoe-lora coverage, test(tiled-mlp): FSDP2 multi-rank correctness)는 변경 사항의 정확성과 안정성을 검증합니다.

일반적인 교훈:

  • 메모리 프로파일링의 중요성: Gradient Accumulator의 dtype 최적화는 메모리 사용량 프로파일링이 얼마나 중요한지를 보여줍니다. 불필요한 FP32 연산은 큰 메모리 풋프린트를 가질 수 있습니다.
  • 하드웨어 특성 고려: MoE Triton 커널의 BLOCK_M과 같은 하드웨어 최적 블록 크기를 고려한 샤딩 전략은 성능에 지대한 영향을 미칩니다.
  • 분산 학습 프레임워크와의 통합: FSDP와 같은 분산 학습 프레임워크의 동작 방식을 이해하고, 커스텀 연산(Tiled-MLP)과의 충돌을 해결하는 것이 중요합니다.
  • 명시적 제어: FP32 누적과 같이 특정 상황에서만 필요한 기능을 환경 변수 등으로 제어할 수 있게 하면, 일반적인 경우의 성능을 희생하지 않으면서도 유연성을 확보할 수 있습니다.

리뷰 피드백 분석

PR 설명에는 구체적인 리뷰 댓글이 포함되어 있지 않지만, PR의 구조와 커밋 메시지를 통해 다음과 같은 점들을 추론할 수 있습니다.

  • 점진적 개발: PR은 여러 개의 작은 커밋으로 구성되어 있으며, 각 커밋은 특정 기능이나 버그 수정을 담당합니다. 이는 코드 리뷰를 용이하게 하고, 문제 발생 시 롤백을 쉽게 합니다. 예를 들어, MoE 블록 패치, FSDP2 픽스, 그래디언트 누적 픽스, 샤드 휴리스틱 픽스가 각각 분리되어 있습니다.
  • 테스트의 중요성: 각 기능별로 테스트 케이스가 추가되었습니다 (test(tiled-mlp): ...). 이는 코드 변경의 정확성을 보장하고 향후 회귀(regression)를 방지하는 데 필수적입니다.
  • 성능 검증: PR 설명에 포함된 상세한 벤치마크 데이터(ms/iter, Peak GiB, Speedup)는 변경 사항의 효과를 정량적으로 입증하며, 실제 프로덕션 환경에서의 권장 설정을 제시합니다. 이는 리뷰어가 변경의 가치를 쉽게 이해하도록 돕습니다.

결론

이번 Axolotl PR은 MoE 모델의 성능과 효율성을 크게 향상시키는 중요한 개선 사항들을 포함하고 있습니다. Tiled-MLP의 MoE 블록 지원, FSDP2와의 통합, Gradient Accumulator의 dtype 최적화, 그리고 긴 컨텍스트를 위한 샤딩 휴리스틱 개선은 모두 실제적인 성능 향상으로 이어졌습니다. 이러한 최적화는 LLM 개발자들이 더 크고 복잡한 모델을 더 효율적으로 학습하고 배포할 수 있도록 지원하며, Axolotl 프레임워크의 경쟁력을 강화합니다.

특히, 메모리 사용량 절감과 긴 컨텍스트에서의 속도 향상은 최근 LLM 연구 및 개발의 주요 트렌드와 일치하며, 이러한 기술적 진보가 오픈 소스 커뮤니티를 통해 공유된다는 점은 매우 고무적입니다.

References

  • torch.compile: PyTorch의 컴파일 기능으로, 커널 융합 및 최적화에 사용될 수 있습니다. (Triton 커널 사용 시 간접적으로 관련)
  • Fully Sharded Data Parallelism (FSDP) — PyTorch documentation: PyTorch의 FSDP 구현에 대한 공식 문서입니다.
  • Triton Documentation: 커스텀 CUDA 커널 작성을 위한 Triton 언어의 공식 문서입니다. PR에서 사용된 커널들이 Triton 기반일 가능성이 높습니다.
  • MXFP4 (Mixed Precision FP4): Microsoft의 TorchAO 라이브러리에서 제안하는 FP4 양자화 방식에 대한 설명입니다. (PR에서 사용된 기술과 관련)

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글