[axolotl] Context Parallel 이중 시퀀스 분할 버그 수정: noop context manager로 중복 적용 방지
PR 링크: axolotl-ai-cloud/axolotl#3498 상태: Merged | 변경: +115 / -3
들어가며
Context Parallel(CP)은 긴 시퀀스를 여러 GPU에 분할하여 처리하는 기법입니다. axolotl은 자체적인 시퀀스 분할 로직(apply_sequence_parallelism)을 가지고 있지만, accelerate 라이브러리도 _prepare_cp 메서드를 통해 시퀀스 분할을 수행합니다. 이 두 메커니즘이 동시에 동작하면 시퀀스가 이중으로 분할되어 학습이 실패하는 버그가 발생했습니다.
핵심 코드 분석
accelerate의 _prepare_cp를 noop으로 대체
기존 코드는 accelerate의 _prepare_cp에서 PyTorch의 context_parallel context manager를 설정하고, 모델에 attention hook을 연결했습니다. axolotl이 자체 시퀀스 분할을 수행하므로 이 동작이 중복됩니다.
Before:
def patched_prepare_cp(self, *args):
if self.parallelism_config.cp_backend == "deepspeed":
return args
from accelerate.big_modeling import _attach_context_parallel_hooks
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
set_rotate_method(cp_comm_strategy)
self._cp_context = functools.partial(
context_parallel, mesh=self.torch_device_mesh["cp"]
)
for arg in args:
if isinstance(arg, torch.nn.Module):
_attach_context_parallel_hooks(arg)
return args
After:
def patched_prepare_cp(self, *args):
if self.parallelism_config.cp_backend == "deepspeed":
return args
@contextlib.contextmanager
def _noop_cp_context(
buffers=None, buffer_seq_dims=None, no_restore_buffers=None
):
yield
self._cp_context = _noop_cp_context
return args
핵심 변경점은 두 가지입니다:
context_parallelcontext manager를 아무 것도 하지 않는 noop으로 대체_attach_context_parallel_hooks호출을 완전히 제거
왜 이게 좋은가
이 수정은 "책임의 분리" 원칙을 따릅니다. axolotl은 자체 시퀀스 분할 메커니즘(SequenceParallelContextManager)을 가지고 있으므로, accelerate의 CP 로직은 비활성화하는 것이 올바른 접근입니다. noop context manager 패턴은 인터페이스 호환성을 유지하면서 동작을 무력화하는 깔끔한 방법입니다. _cp_context를 아예 None으로 설정하지 않고 noop으로 대체한 이유는, 호출부에서 with self._cp_context(...) 형태로 사용하기 때문에 context manager 프로토콜을 유지해야 하기 때문입니다.
정리
| 항목 | 내용 |
|---|---|
| 문제 | accelerate + axolotl 모두 시퀀스 분할을 수행하여 이중 적용 |
| 해결 | accelerate의 _prepare_cp를 noop context manager로 대체 |
| 영향 | Context Parallel 학습 안정성 확보 |
참고 자료
알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [axolotl] Context Parallelism batch_size 및 total_num_steps 계산 수정
- [axolotl] Tensor Parallelism batch_size 계산 버그 수정: dp_world_size 기반으로 전환
- [axolotl] FSDP CPU RAM Efficient Loading 패치: non-rank-0 프로세스의 불필요한 가중치 초기화 방지
- [axolotl] SchedulerMixin.create_scheduler() optimizer 누락 버그 수정
- [sglang] HiCache 메모리 누수 수정: host indices clone으로 참조 해제 보장
PR Analysis 의 다른글
- 이전글 [PaddleOCR] MCP 서버에서 모든 OCR 결과 배치를 파싱하도록 수정
- 현재글 : [axolotl] Context Parallel 이중 시퀀스 분할 버그 수정: noop context manager로 중복 적용 방지
- 다음글 [triton] Global Sanitizer에 TMA 및 cp.async 연산 부분 지원 추가
댓글