본문으로 건너뛰기

[Triton] ReduceOp 로우어링을 LinearLayout 기반으로 개선 및 단순화

PR 링크: triton-lang/triton#9192 상태: Merged | 변경: +511 / -424

들어가며

Triton의 ReduceOp(텐서의 특정 축을 따라 값을 축소하는 연산, 예: sum, max)의 로우어링은 기존에 복잡한 수동 인덱싱 로직으로 구현되어 있었다. 이 PR은 Triton의 LinearLayout 인프라를 활용하여 로우어링을 재설계한다. ConvertLayout과 LinearLayout 헬퍼가 대부분의 작업을 처리하므로, 코드가 대폭 단순해진다.

핵심 코드 분석

Before (Utility.h)

class ReduceOpHelper {
  ArrayRef<int64_t> getSrcShape();
  Attribute getSrcLayout();
  triton::ReduceOp getOperation();
  unsigned getThreadOffsetOnReductionAxis();
  SmallVector<unsigned> getScratchRepShape();
  SmallVector<unsigned> getOrderWithAxisAtBeginning();
  unsigned getScratchSizeInBytes();
  // ... 복잡한 인덱싱 로직
};

After (Utility.h)

class ReduceOpHelper {
  RankedTensorType getSrcTy();
  bool isWarpSynchronous();
  bool isReduceWithinCTA();
  bool isAssociative();

  static ColumnAction makeAxisContiguous(
      const LinearLayout &layout, int axis);
  static LinearLayout zeroBasesAlongDimAndReorder(
      const LinearLayout &layout, unsigned axis, StringAttr dim);
  static LinearLayout getInterLayout(
      const LinearLayout &layout, unsigned axis);
  static LinearLayout reducedRegLaneLayout(
      RankedTensorType srcTy, unsigned axis);

  SmallVector<unsigned> getScratchBytesForCvt(
      const LinearLayout &srcLayout,
      const LinearLayout &dstLayout);
};

왜 이게 좋은가

  • 코드 단순화: ConvertLayout이 복잡한 shared memory 접근 패턴을 자동 처리하므로, ReduceOp 구현에서 수동 인덱싱이 사라진다.
  • Shmem swizzling 무료 제공: LinearLayout이 swizzling을 자동 적용하여, bank conflict 없는 shared memory 접근이 보장된다.
  • 불필요한 round-trip 제거: 기존에는 무조건 shmem round-trip을 수행했지만, 새 구현에서는 warp-synchronous 경우 생략할 수 있다.
  • LOC 중립: 511줄 추가, 424줄 삭제로, 기능은 확장되면서 코드량은 비슷하다.

정리

LinearLayout이라는 추상화를 활용하면, 복잡한 GPU 메모리 접근 패턴을 일관된 방식으로 표현할 수 있다. 이 PR은 잘 설계된 추상화가 코드 단순화와 성능 개선을 동시에 달성할 수 있음을 보여주는 좋은 사례다.

참고 자료


이 글은 AI 도구의 도움을 받아 작성되었습니다.

댓글

관련 포스트

PR Analysis 의 다른글