본문으로 건너뛰기

[axolotl] Context Parallel 이중 시퀀스 분할 버그 수정: noop context manager로 중복 적용 방지

PR 링크: axolotl-ai-cloud/axolotl#3498 상태: Merged | 변경: +115 / -3

들어가며

Context Parallel(CP)은 긴 시퀀스를 여러 GPU에 분할하여 처리하는 기법입니다. axolotl은 자체적인 시퀀스 분할 로직(apply_sequence_parallelism)을 가지고 있지만, accelerate 라이브러리도 _prepare_cp 메서드를 통해 시퀀스 분할을 수행합니다. 이 두 메커니즘이 동시에 동작하면 시퀀스가 이중으로 분할되어 학습이 실패하는 버그가 발생했습니다.

핵심 코드 분석

accelerate의 _prepare_cp를 noop으로 대체

기존 코드는 accelerate의 _prepare_cp에서 PyTorch의 context_parallel context manager를 설정하고, 모델에 attention hook을 연결했습니다. axolotl이 자체 시퀀스 분할을 수행하므로 이 동작이 중복됩니다.

Before:

def patched_prepare_cp(self, *args):
    if self.parallelism_config.cp_backend == "deepspeed":
        return args

    from accelerate.big_modeling import _attach_context_parallel_hooks
    from torch.distributed.tensor.experimental import context_parallel
    from torch.distributed.tensor.experimental._attention import set_rotate_method

    cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
    set_rotate_method(cp_comm_strategy)

    self._cp_context = functools.partial(
        context_parallel, mesh=self.torch_device_mesh["cp"]
    )

    for arg in args:
        if isinstance(arg, torch.nn.Module):
            _attach_context_parallel_hooks(arg)

    return args

After:

def patched_prepare_cp(self, *args):
    if self.parallelism_config.cp_backend == "deepspeed":
        return args

    @contextlib.contextmanager
    def _noop_cp_context(
        buffers=None, buffer_seq_dims=None, no_restore_buffers=None
    ):
        yield

    self._cp_context = _noop_cp_context
    return args

핵심 변경점은 두 가지입니다:

  1. context_parallel context manager를 아무 것도 하지 않는 noop으로 대체
  2. _attach_context_parallel_hooks 호출을 완전히 제거

왜 이게 좋은가

이 수정은 "책임의 분리" 원칙을 따릅니다. axolotl은 자체 시퀀스 분할 메커니즘(SequenceParallelContextManager)을 가지고 있으므로, accelerate의 CP 로직은 비활성화하는 것이 올바른 접근입니다. noop context manager 패턴은 인터페이스 호환성을 유지하면서 동작을 무력화하는 깔끔한 방법입니다. _cp_context를 아예 None으로 설정하지 않고 noop으로 대체한 이유는, 호출부에서 with self._cp_context(...) 형태로 사용하기 때문에 context manager 프로토콜을 유지해야 하기 때문입니다.

정리

항목 내용
문제 accelerate + axolotl 모두 시퀀스 분할을 수행하여 이중 적용
해결 accelerate의 _prepare_cp를 noop context manager로 대체
영향 Context Parallel 학습 안정성 확보

참고 자료

알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글