본문으로 건너뛰기

[axolotl] FSDP CPU RAM Efficient Loading 패치: non-rank-0 프로세스의 불필요한 가중치 초기화 방지

PR 링크: axolotl-ai-cloud/axolotl#3464 상태: Merged | 변경: +40 / -0

들어가며

FSDP(Fully Sharded Data Parallel)에서 cpu_ram_efficient_loading을 사용하면, rank 0만 실제 가중치를 로드하고 나머지 프로세스는 meta device에 빈 텐서를 생성합니다. 그런데 transformers의 _initialize_missing_keys가 모든 rank에서 가중치를 다시 초기화하여, 불필요한 메모리 사용과 성능 저하를 유발하는 문제가 있었습니다.

핵심 코드 분석

_initialize_missing_keys 패치

def patch_initialize_missing_keys_for_fsdp():
    from transformers import PreTrainedModel
    from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0

    _original_initialize_missing_keys = PreTrainedModel._initialize_missing_keys

    def _patched_initialize_missing_keys(self, is_quantized: bool) -> None:
        if is_fsdp_enabled() and not is_local_dist_rank_0():
            for key in self.state_dict():
                param_or_buffer = self.get_parameter_or_buffer(key)
                param_or_buffer._is_hf_initialized = True
            self._is_hf_initialized = True

        _original_initialize_missing_keys(self, is_quantized)

    PreTrainedModel._initialize_missing_keys = _patched_initialize_missing_keys

핵심 아이디어는 non-rank-0 프로세스에서 모든 파라미터에 _is_hf_initialized = True 플래그를 설정하는 것입니다. transformers의 guarded init 함수(init.normal_, init.zeros_ 등)는 이 플래그가 True인 파라미터를 건너뛰므로, 실질적으로 재초기화가 no-op이 됩니다. 실제 가중치는 나중에 FSDP의 rank 0 broadcast를 통해 전달됩니다.

패치 적용 위치

def _apply_fsdp_patches(self):
    if self.cfg.fsdp_config:
        from axolotl.monkeypatch.accelerate.fsdp2 import (
            patch_initialize_missing_keys_for_fsdp,
        )
        patch_initialize_missing_keys_for_fsdp()

fsdp_config가 설정된 경우에만 패치를 적용하며, FSDP2뿐만 아니라 FSDP1에서도 동작합니다.

왜 이게 좋은가

이 패치의 가치는 메모리 효율성에 있습니다. 대형 모델(70B+ 파라미터)을 FSDP로 학습할 때, non-rank-0 프로세스에서 불필요한 가중치 초기화는 프로세스당 수 GB의 추가 메모리를 소비할 수 있습니다. _is_hf_initialized 플래그를 활용한 해결 방법은 transformers의 기존 메커니즘을 그대로 활용하여, 최소한의 코드 변경으로 문제를 해결합니다. PR 설명에 upstream fix 링크(transformers#44473)를 명시하여, 향후 패치 제거 시점도 명확히 한 점이 좋습니다.

정리

항목 내용
문제 non-rank-0에서 불필요한 가중치 재초기화
해결 _is_hf_initialized 플래그 설정으로 init 건너뛰기
영향 FSDP 분산 학습 시 메모리 사용량 감소, 초기화 속도 향상

참고 자료

알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글