[axolotl] FSDP CPU RAM Efficient Loading 패치: non-rank-0 프로세스의 불필요한 가중치 초기화 방지
PR 링크: axolotl-ai-cloud/axolotl#3464 상태: Merged | 변경: +40 / -0
들어가며
FSDP(Fully Sharded Data Parallel)에서 cpu_ram_efficient_loading을 사용하면, rank 0만 실제 가중치를 로드하고 나머지 프로세스는 meta device에 빈 텐서를 생성합니다. 그런데 transformers의 _initialize_missing_keys가 모든 rank에서 가중치를 다시 초기화하여, 불필요한 메모리 사용과 성능 저하를 유발하는 문제가 있었습니다.
핵심 코드 분석
_initialize_missing_keys 패치
def patch_initialize_missing_keys_for_fsdp():
from transformers import PreTrainedModel
from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0
_original_initialize_missing_keys = PreTrainedModel._initialize_missing_keys
def _patched_initialize_missing_keys(self, is_quantized: bool) -> None:
if is_fsdp_enabled() and not is_local_dist_rank_0():
for key in self.state_dict():
param_or_buffer = self.get_parameter_or_buffer(key)
param_or_buffer._is_hf_initialized = True
self._is_hf_initialized = True
_original_initialize_missing_keys(self, is_quantized)
PreTrainedModel._initialize_missing_keys = _patched_initialize_missing_keys
핵심 아이디어는 non-rank-0 프로세스에서 모든 파라미터에 _is_hf_initialized = True 플래그를 설정하는 것입니다. transformers의 guarded init 함수(init.normal_, init.zeros_ 등)는 이 플래그가 True인 파라미터를 건너뛰므로, 실질적으로 재초기화가 no-op이 됩니다. 실제 가중치는 나중에 FSDP의 rank 0 broadcast를 통해 전달됩니다.
패치 적용 위치
def _apply_fsdp_patches(self):
if self.cfg.fsdp_config:
from axolotl.monkeypatch.accelerate.fsdp2 import (
patch_initialize_missing_keys_for_fsdp,
)
patch_initialize_missing_keys_for_fsdp()
fsdp_config가 설정된 경우에만 패치를 적용하며, FSDP2뿐만 아니라 FSDP1에서도 동작합니다.
왜 이게 좋은가
이 패치의 가치는 메모리 효율성에 있습니다. 대형 모델(70B+ 파라미터)을 FSDP로 학습할 때, non-rank-0 프로세스에서 불필요한 가중치 초기화는 프로세스당 수 GB의 추가 메모리를 소비할 수 있습니다. _is_hf_initialized 플래그를 활용한 해결 방법은 transformers의 기존 메커니즘을 그대로 활용하여, 최소한의 코드 변경으로 문제를 해결합니다. PR 설명에 upstream fix 링크(transformers#44473)를 명시하여, 향후 패치 제거 시점도 명확히 한 점이 좋습니다.
정리
| 항목 | 내용 |
|---|---|
| 문제 | non-rank-0에서 불필요한 가중치 재초기화 |
| 해결 | _is_hf_initialized 플래그 설정으로 init 건너뛰기 |
| 영향 | FSDP 분산 학습 시 메모리 사용량 감소, 초기화 속도 향상 |
참고 자료
알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [논문리뷰] veScale-FSDP: Flexible and High-Performance FSDP at Scale
- [axolotl] Context Parallel 이중 시퀀스 분할 버그 수정: noop context manager로 중복 적용 방지
- [axolotl] Tensor Parallelism batch_size 계산 버그 수정: dp_world_size 기반으로 전환
- [Axolotl] 가중치 동기 로딩으로 OOM 방지
- [axolotl] Context Parallelism batch_size 및 total_num_steps 계산 수정
PR Analysis 의 다른글
- 이전글 [vllm] FlashInfer MoE A2A Kernel - NVLink 기반 Expert Parallelism 통신
- 현재글 : [axolotl] FSDP CPU RAM Efficient Loading 패치: non-rank-0 프로세스의 불필요한 가중치 초기화 방지
- 다음글 [Ray Core] OOM Killer에서 대용량 메모리를 점유한 유휴 워커를 우선 종료
댓글