본문으로 건너뛰기

[axolotl] Context Parallelism batch_size 및 total_num_steps 계산 수정

PR 링크: axolotl-ai-cloud/axolotl#3444 상태: Merged | 변경: +64 / -11

들어가며

Context Parallelism(CP)은 긴 시퀀스를 여러 GPU에 분할하여 처리하는 기법으로, 각 CP 그룹의 GPU는 같은 배치의 서로 다른 시퀀스 조각을 처리합니다. 기존 코드는 batch_size 계산 시 world_size를 그대로 사용하고, total_num_steps에서 context_parallel_size를 곱하는 방식으로 보상하려 했지만, 이 접근이 근본적으로 잘못되어 있었습니다.

핵심 코드 분석

1. batch_size 계산: effective world_size 도입

Before:

cfg.batch_size = cfg.batch_size * cfg.world_size

After:

effective_world_size = cfg.world_size // (cfg.context_parallel_size or 1)
cfg.batch_size = cfg.batch_size * effective_world_size

4 GPU에서 CP=2이면, 실제 데이터 병렬 차원은 2(= 4 // 2)입니다. batch_size는 이 dp_world_size로만 스케일해야 합니다.

2. total_num_steps에서 CP size 곱셈 제거

Before:

total_num_steps = int(
    math.ceil(
        len(train_dataset) * cfg.num_epochs
        * cfg.context_parallel_size * cfg.tensor_parallel_size
        / cfg.batch_size
    )
)

After:

total_num_steps = int(
    math.ceil(
        len(train_dataset) * cfg.num_epochs
        * cfg.tensor_parallel_size / cfg.batch_size
    )
)

batch_size가 올바르게 계산되면, total_num_steps에서 CP size를 곱할 필요가 없습니다. sample packing 경로에서도 동일하게 CP size 곱셈을 제거했습니다.

3. 파라미터화된 테스트

@pytest.mark.parametrize(
    "world_size, context_parallel_size, expected_batch_size",
    [
        (4, 1, 32),   # no CP: 2*4*4 = 32
        (4, 2, 16),   # CP=2: 2*4*(4//2) = 16
        (4, 4, 8),    # CP=4: 2*4*(4//4) = 8
        (2, 2, 8),    # CP=ws: 2*4*(2//2) = 8
    ],
)
def test_batch_size_with_context_parallelism(self, ...):

왜 이게 좋은가

이전 접근법은 "잘못된 batch_size를 total_num_steps에서 보상"하는 방식이었는데, 이는 sample packing이나 다른 경로에서 일관되지 않은 결과를 만들었습니다. 새로운 접근법은 batch_size를 처음부터 올바르게 계산하여, 이후 모든 계산이 자연스럽게 정확해집니다. 이 PR은 이후 PR #3462(Tensor Parallelism batch_size 수정)의 기반이 되었습니다.

정리

항목 내용
문제 CP 환경에서 batch_size 과대 계산 + total_num_steps 보상 방식
해결 effective_world_size = world_size // CP, total_num_steps에서 CP 곱셈 제거
영향 CP 학습 시 올바른 step 수와 학습률 스케줄 보장

참고 자료

알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글