[Triton] ReduceOp 로우어링을 LinearLayout 기반으로 개선 및 단순화
PR 링크: triton-lang/triton#9192 상태: Merged | 변경: +511 / -424
들어가며
Triton의 ReduceOp(텐서의 특정 축을 따라 값을 축소하는 연산, 예: sum, max)의 로우어링은 기존에 복잡한 수동 인덱싱 로직으로 구현되어 있었다. 이 PR은 Triton의 LinearLayout 인프라를 활용하여 로우어링을 재설계한다. ConvertLayout과 LinearLayout 헬퍼가 대부분의 작업을 처리하므로, 코드가 대폭 단순해진다.
핵심 코드 분석
Before (Utility.h)
class ReduceOpHelper {
ArrayRef<int64_t> getSrcShape();
Attribute getSrcLayout();
triton::ReduceOp getOperation();
unsigned getThreadOffsetOnReductionAxis();
SmallVector<unsigned> getScratchRepShape();
SmallVector<unsigned> getOrderWithAxisAtBeginning();
unsigned getScratchSizeInBytes();
// ... 복잡한 인덱싱 로직
};
After (Utility.h)
class ReduceOpHelper {
RankedTensorType getSrcTy();
bool isWarpSynchronous();
bool isReduceWithinCTA();
bool isAssociative();
static ColumnAction makeAxisContiguous(
const LinearLayout &layout, int axis);
static LinearLayout zeroBasesAlongDimAndReorder(
const LinearLayout &layout, unsigned axis, StringAttr dim);
static LinearLayout getInterLayout(
const LinearLayout &layout, unsigned axis);
static LinearLayout reducedRegLaneLayout(
RankedTensorType srcTy, unsigned axis);
SmallVector<unsigned> getScratchBytesForCvt(
const LinearLayout &srcLayout,
const LinearLayout &dstLayout);
};
왜 이게 좋은가
- 코드 단순화: ConvertLayout이 복잡한 shared memory 접근 패턴을 자동 처리하므로, ReduceOp 구현에서 수동 인덱싱이 사라진다.
- Shmem swizzling 무료 제공: LinearLayout이 swizzling을 자동 적용하여, bank conflict 없는 shared memory 접근이 보장된다.
- 불필요한 round-trip 제거: 기존에는 무조건 shmem round-trip을 수행했지만, 새 구현에서는 warp-synchronous 경우 생략할 수 있다.
- LOC 중립: 511줄 추가, 424줄 삭제로, 기능은 확장되면서 코드량은 비슷하다.
정리
LinearLayout이라는 추상화를 활용하면, 복잡한 GPU 메모리 접근 패턴을 일관된 방식으로 표현할 수 있다. 이 PR은 잘 설계된 추상화가 코드 단순화와 성능 개선을 동시에 달성할 수 있음을 보여주는 좋은 사례다.
참고 자료
이 글은 AI 도구의 도움을 받아 작성되었습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [pytorch] CI: fbgemm/torchrec 핀 버전 업데이트 및 빌드 로직 리팩토링
- 현재글 : [Triton] ReduceOp 로우어링을 LinearLayout 기반으로 개선 및 단순화
- 다음글 [pytorch] CI: Inductor 테스트에 IoU 기반 accuracy 체크를 추가하여 segmentation 모델 안정화
댓글