[triton] Reduce 커널에 Unpadded Batch Size 핸들링 추가
PR 링크: triton-lang/triton#9332 상태: Merged | 변경: +32 / -6
들어가며
MoE(Mixture of Experts) 워크로드에서 expert별 토큰 수가 다를 때, 배치를 최대 크기로 패딩하면 실제 데이터가 없는 행에 대해서도 reduce 연산을 수행하는 낭비가 발생합니다. 이 PR은 unpadded_batch_size 파라미터를 추가하여 실제 배치 크기만큼만 연산하도록 합니다.
핵심 코드 분석
Before
valid_s0 = offs_s0 < S0 # S0은 패딩된 전체 배치 크기
After
if UnpaddedBatchSize is not None:
unpadded = tl.load(UnpaddedBatchSize).to(tl.int32)
if pid_s0 * BLOCK_S0 >= unpadded:
return # 패딩 영역은 전체 프로그램 ID 단위로 스킵
valid_s0 = offs_s0 < unpadded
else:
valid_s0 = offs_s0 < S0
Python API
def reduce(x, dim, mask=None, scale=None, ...,
unpadded_batch_size: Optional[torch.Tensor] = None):
"""Optional single-element tensor specifying the number of entries
to reduce along the first dimension."""
왜 이게 좋은가
- Early exit: 패딩 영역에 해당하는 프로그램 ID는 즉시 반환하여 GPU 자원을 절약합니다.
- 동적 배치:
unpadded_batch_size가 텐서이므로 런타임에 배치 크기가 변해도 재컴파일 없이 동작합니다. - 역전파 지원: forward와 backward 모두에 동일한 로직이 적용됩니다.
정리
MoE 워크로드에서 패딩된 배치의 불필요한 reduce 연산을 제거하는 최적화입니다. 텐서 기반의 동적 배치 크기 전달로 재컴파일 없이 유연하게 동작합니다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [uvloop] uvloop의 SSL 성능 최적화: SSLWantReadError 비용 줄이기
- 현재글 : [triton] Reduce 커널에 Unpadded Batch Size 핸들링 추가
- 다음글 [triton] Triton AMD GPU 백엔드: v_perm 명령어를 활용한 레이아웃 변환 최적화
댓글