본문으로 건너뛰기

[Triton] Expert Parallelism 커널 버그 수정

들어가며

Triton의 multi-GPU expert parallelism 구현에서 분산 테스트 인프라와 routing 로직에 버그가 있었다. 이 PR은 테스트를 torchrun 의존 없이 mp.spawn으로 실행 가능하게 만들고, routing 커널의 rank 계산 버그를 수정한다.

핵심 코드 분석

Before (routing 버그)

expt_rank = tl.sum(offs_r[:, None] * expt_filter, axis=0) > 0

rank를 boolean으로 변환하여 0 또는 1만 반환했다. multi-GPU 환경에서 실제 rank 번호가 필요한데 True/False만 얻게 되는 버그였다.

After

expt_rank = tl.sum(offs_r[:, None] * expt_filter, axis=0)

> 0 비교를 제거하여 실제 rank 값을 보존한다.

Before (테스트 인프라)

@pytest.fixture(scope="session", autouse=True)
def init_distributed():
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        pytest.skip("Launch with torchrun")

torchrun으로만 실행 가능하여 CI에서 사용이 불편했다.

After

def _distributed_worker(rank, fn, world_size, kwargs):
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size, ...)
    fn(rank=rank, world_size=world_size, **kwargs)

def launch(fn, **kwargs):
    mp.spawn(_distributed_worker, args=(fn, n_gpus, kwargs), nprocs=n_gpus, join=True)

mp.spawn 기반으로 변경하여 일반 pytest에서도 multi-GPU 테스트가 가능해졌다.

왜 이게 좋은가

  • 정확한 rank 계산: 분산 환경에서 expert가 올바른 GPU로 라우팅된다.
  • 테스트 접근성: torchrun 없이 pytest로 직접 실행 가능하여 CI 통합이 쉬워졌다.
  • 불필요한 동기화 제거: torch.cuda.synchronize()와 중복 barrier 호출이 정리되었다.

정리

분산 시스템의 작은 타입 변환 실수(> 0)가 routing 전체를 망가뜨릴 수 있다는 교훈과 함께, 테스트 인프라 개선이 버그 발견을 용이하게 한다는 점을 보여주는 PR이다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글