[Triton] Expert Parallelism 커널 버그 수정
들어가며
Triton의 multi-GPU expert parallelism 구현에서 분산 테스트 인프라와 routing 로직에 버그가 있었다. 이 PR은 테스트를 torchrun 의존 없이 mp.spawn으로 실행 가능하게 만들고, routing 커널의 rank 계산 버그를 수정한다.
핵심 코드 분석
Before (routing 버그)
expt_rank = tl.sum(offs_r[:, None] * expt_filter, axis=0) > 0
rank를 boolean으로 변환하여 0 또는 1만 반환했다. multi-GPU 환경에서 실제 rank 번호가 필요한데 True/False만 얻게 되는 버그였다.
After
expt_rank = tl.sum(offs_r[:, None] * expt_filter, axis=0)
> 0 비교를 제거하여 실제 rank 값을 보존한다.
Before (테스트 인프라)
@pytest.fixture(scope="session", autouse=True)
def init_distributed():
if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
pytest.skip("Launch with torchrun")
torchrun으로만 실행 가능하여 CI에서 사용이 불편했다.
After
def _distributed_worker(rank, fn, world_size, kwargs):
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size, ...)
fn(rank=rank, world_size=world_size, **kwargs)
def launch(fn, **kwargs):
mp.spawn(_distributed_worker, args=(fn, n_gpus, kwargs), nprocs=n_gpus, join=True)
mp.spawn 기반으로 변경하여 일반 pytest에서도 multi-GPU 테스트가 가능해졌다.
왜 이게 좋은가
- 정확한 rank 계산: 분산 환경에서 expert가 올바른 GPU로 라우팅된다.
- 테스트 접근성:
torchrun없이pytest로 직접 실행 가능하여 CI 통합이 쉬워졌다. - 불필요한 동기화 제거:
torch.cuda.synchronize()와 중복 barrier 호출이 정리되었다.
정리
분산 시스템의 작은 타입 변환 실수(> 0)가 routing 전체를 망가뜨릴 수 있다는 교훈과 함께, 테스트 인프라 개선이 버그 발견을 용이하게 한다는 점을 보여주는 PR이다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글