[axolotl] Context Parallelism batch_size 및 total_num_steps 계산 수정
PR 링크: axolotl-ai-cloud/axolotl#3444 상태: Merged | 변경: +64 / -11
들어가며
Context Parallelism(CP)은 긴 시퀀스를 여러 GPU에 분할하여 처리하는 기법으로, 각 CP 그룹의 GPU는 같은 배치의 서로 다른 시퀀스 조각을 처리합니다. 기존 코드는 batch_size 계산 시 world_size를 그대로 사용하고, total_num_steps에서 context_parallel_size를 곱하는 방식으로 보상하려 했지만, 이 접근이 근본적으로 잘못되어 있었습니다.
핵심 코드 분석
1. batch_size 계산: effective world_size 도입
Before:
cfg.batch_size = cfg.batch_size * cfg.world_size
After:
effective_world_size = cfg.world_size // (cfg.context_parallel_size or 1)
cfg.batch_size = cfg.batch_size * effective_world_size
4 GPU에서 CP=2이면, 실제 데이터 병렬 차원은 2(= 4 // 2)입니다. batch_size는 이 dp_world_size로만 스케일해야 합니다.
2. total_num_steps에서 CP size 곱셈 제거
Before:
total_num_steps = int(
math.ceil(
len(train_dataset) * cfg.num_epochs
* cfg.context_parallel_size * cfg.tensor_parallel_size
/ cfg.batch_size
)
)
After:
total_num_steps = int(
math.ceil(
len(train_dataset) * cfg.num_epochs
* cfg.tensor_parallel_size / cfg.batch_size
)
)
batch_size가 올바르게 계산되면, total_num_steps에서 CP size를 곱할 필요가 없습니다. sample packing 경로에서도 동일하게 CP size 곱셈을 제거했습니다.
3. 파라미터화된 테스트
@pytest.mark.parametrize(
"world_size, context_parallel_size, expected_batch_size",
[
(4, 1, 32), # no CP: 2*4*4 = 32
(4, 2, 16), # CP=2: 2*4*(4//2) = 16
(4, 4, 8), # CP=4: 2*4*(4//4) = 8
(2, 2, 8), # CP=ws: 2*4*(2//2) = 8
],
)
def test_batch_size_with_context_parallelism(self, ...):
왜 이게 좋은가
이전 접근법은 "잘못된 batch_size를 total_num_steps에서 보상"하는 방식이었는데, 이는 sample packing이나 다른 경로에서 일관되지 않은 결과를 만들었습니다. 새로운 접근법은 batch_size를 처음부터 올바르게 계산하여, 이후 모든 계산이 자연스럽게 정확해집니다. 이 PR은 이후 PR #3462(Tensor Parallelism batch_size 수정)의 기반이 되었습니다.
정리
| 항목 | 내용 |
|---|---|
| 문제 | CP 환경에서 batch_size 과대 계산 + total_num_steps 보상 방식 |
| 해결 | effective_world_size = world_size // CP, total_num_steps에서 CP 곱셈 제거 |
| 영향 | CP 학습 시 올바른 step 수와 학습률 스케줄 보장 |
참고 자료
알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [axolotl] Context Parallel 이중 시퀀스 분할 버그 수정: noop context manager로 중복 적용 방지
- [axolotl] Tensor Parallelism batch_size 계산 버그 수정: dp_world_size 기반으로 전환
- [axolotl] FSDP CPU RAM Efficient Loading 패치: non-rank-0 프로세스의 불필요한 가중치 초기화 방지
- [axolotl] SchedulerMixin.create_scheduler() optimizer 누락 버그 수정
- [sglang] HiCache 메모리 누수 수정: host indices clone으로 참조 해제 보장
PR Analysis 의 다른글
- 이전글 [axolotl] SchedulerMixin.create_scheduler() optimizer 누락 버그 수정
- 현재글 : [axolotl] Context Parallelism batch_size 및 total_num_steps 계산 수정
- 다음글 [Triton] FenceAsync에 비동기 읽기 의존성 추가 — st.shared와 copy_local_to_global 간 정합성 보장
댓글