본문으로 건너뛰기

[axolotl] Tensor Parallelism batch_size 계산 버그 수정: dp_world_size 기반으로 전환

PR 링크: axolotl-ai-cloud/axolotl#3462 상태: Merged | 변경: +61 / -11

들어가며

Tensor Parallelism(TP)은 하나의 모델 레이어를 여러 GPU에 분할하는 기법입니다. TP를 사용하면 world_size(전체 GPU 수)와 dp_world_size(데이터 병렬 GPU 수)가 달라집니다. 예를 들어 4 GPU에서 TP=2이면 dp_world_size는 2입니다. 이 PR은 batch_size와 total_num_steps 계산에서 world_size 대신 dp_world_size를 사용하도록 수정합니다.

핵심 코드 분석

1. batch_size 계산에 tensor_parallel_size 반영

Before:

effective_world_size = cfg.world_size // (cfg.context_parallel_size or 1)
cfg.batch_size = cfg.batch_size * effective_world_size

After:

effective_world_size = (
    cfg.world_size
    // (cfg.context_parallel_size or 1)
    // (cfg.tensor_parallel_size or 1)
)
cfg.batch_size = cfg.batch_size * effective_world_size

TP GPU는 같은 데이터를 처리하므로, batch_size 스케일링에서 제외해야 합니다.

2. total_num_steps에서 tensor_parallel_size 곱셈 제거

Before:

total_num_steps = int(
    math.ceil(len(train_dataset) * cfg.num_epochs
              * cfg.tensor_parallel_size / cfg.batch_size)
)

After:

total_num_steps = int(
    math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)

기존 코드는 TP size를 total_num_steps에 곱해서 보상하려 했지만, 이는 근본적으로 잘못된 접근이었습니다. batch_size 자체를 올바르게 계산하면 이런 보상이 필요 없습니다.

3. 파라미터화된 테스트 추가

@pytest.mark.parametrize(
    "world_size, tensor_parallel_size, expected_batch_size",
    [
        (4, 1, 32),  # no TP: 2*4*4 = 32
        (4, 2, 16),  # TP=2: 2*4*(4//2) = 16
        (4, 4, 8),   # TP=4: 2*4*(4//4) = 8
        (2, 2, 8),   # TP=ws: 2*4*(2//2) = 8
    ],
)
def test_batch_size_with_tensor_parallelism(self, ...):
    ...
    assert cfg.batch_size == expected_batch_size

왜 이게 좋은가

이 버그는 TP 사용 시 학습이 의도보다 더 많거나 적은 step을 실행하고, batch_size가 과대 계산되어 OOM이나 학습 불안정을 유발할 수 있었습니다. 수정의 핵심 통찰은 "batch_size는 데이터 병렬 차원에서만 스케일해야 한다"는 것입니다. world_size에서 CP와 TP를 모두 나누어 순수 dp_world_size를 구하는 공식이 명확하고, 테스트가 경계 조건(TP=world_size)까지 검증합니다.

정리

항목 내용
문제 TP 환경에서 batch_size/total_num_steps 과대 계산
해결 effective_world_size = world_size // CP // TP
테스트 4가지 TP 조합에 대한 파라미터화 테스트

참고 자료

알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글