본문으로 건너뛰기

[triton] Reduce 커널에 Unpadded Batch Size 핸들링 추가

PR 링크: triton-lang/triton#9332 상태: Merged | 변경: +32 / -6

들어가며

MoE(Mixture of Experts) 워크로드에서 expert별 토큰 수가 다를 때, 배치를 최대 크기로 패딩하면 실제 데이터가 없는 행에 대해서도 reduce 연산을 수행하는 낭비가 발생합니다. 이 PR은 unpadded_batch_size 파라미터를 추가하여 실제 배치 크기만큼만 연산하도록 합니다.

핵심 코드 분석

Before

valid_s0 = offs_s0 < S0  # S0은 패딩된 전체 배치 크기

After

if UnpaddedBatchSize is not None:
    unpadded = tl.load(UnpaddedBatchSize).to(tl.int32)
    if pid_s0 * BLOCK_S0 >= unpadded:
        return  # 패딩 영역은 전체 프로그램 ID 단위로 스킵
    valid_s0 = offs_s0 < unpadded
else:
    valid_s0 = offs_s0 < S0

Python API

def reduce(x, dim, mask=None, scale=None, ...,
           unpadded_batch_size: Optional[torch.Tensor] = None):
    """Optional single-element tensor specifying the number of entries
    to reduce along the first dimension."""

왜 이게 좋은가

  1. Early exit: 패딩 영역에 해당하는 프로그램 ID는 즉시 반환하여 GPU 자원을 절약합니다.
  2. 동적 배치: unpadded_batch_size가 텐서이므로 런타임에 배치 크기가 변해도 재컴파일 없이 동작합니다.
  3. 역전파 지원: forward와 backward 모두에 동일한 로직이 적용됩니다.

정리

MoE 워크로드에서 패딩된 배치의 불필요한 reduce 연산을 제거하는 최적화입니다. 텐서 기반의 동적 배치 크기 전달로 재컴파일 없이 유연하게 동작합니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글