본문으로 건너뛰기

[triton] AMD: WMMA layout의 CTA 필드를 LinearLayout으로 일반화하여 swizzled warp 레이아웃 지원

PR 링크: triton-lang/triton#8946 상태: Merged | 변경: +901 / -767

들어가며

AMD GPU의 WMMA(Wave Matrix Multiply-Accumulate) layout은 warp가 WMMA 타일에 어떻게 매핑되는지를 정의합니다. 기존에는 warpsPerCTAtilesPerWarp 두 정수 배열로 표현했는데, 이로는 gfx1250에서 LDS 파티션 충돌을 줄이기 위한 swizzled warp 레이아웃을 표현할 수 없었습니다.

핵심 코드 분석

Before: warpsPerCTA + tilesPerWarp

// TritonGPUAttrDefs.td
let parameters = (ins
  "unsigned": $version,
  "bool":$isTransposed,
  ArrayRefParameter<"unsigned">:$warpsPerCTA,
  ArrayRefParameter<"unsigned">:$tilesPerWarp,
  "CGAEncodingAttr":$CGALayout,
  ArrayRefParameter<"unsigned">:$instrShape
);

이 방식은 warpsPerCTA = [2, 2], tilesPerWarp = [1, 1]로 다음 배치만 표현 가능:

w0 w1 w0 w1
w2 w3 w2 w3

After: LinearLayout 기반 ctaLayout

let parameters = (ins
  "unsigned": $version,
  LinearLayoutParam:$ctaLayout,  // warpsPerCTA+tilesPerWarp 대체
  "bool":$isTransposed,
  "CGAEncodingAttr":$CGALayout,
  ArrayRefParameter<"unsigned">:$instrShape
);

swizzled 레이아웃 표현 가능:

ctaLayout = {reg = [[2, 0]], warps = [[2, 1], [1, 0]]}
// 결과:
w0 w1    <- w1의 두 번째 타일
w2 w3
w0 w1    <- w1의 첫 번째 타일
w2 w3

최종 WMMA layout은 타일 내부 layout과 CTA layout의 합성으로 단순화됩니다:

wmmaLayout = tileLayout * ctaLayout

왜 이게 좋은가

  1. 표현력 확장: swizzled warp 레이아웃으로 LDS 파티션 충돌을 줄일 수 있습니다.
  2. lowering 단순화: tileLayout과 ctaLayout의 합성으로 WMMA/dotOperand lowering이 단순해졌습니다.
  3. 하위 호환: 기존 warpsPerCTA/tilesPerWarp 조합은 ctaLayout으로 자동 변환됩니다.

정리

레이아웃 표현을 LinearLayout으로 일반화하는 것은 Triton 컴파일러의 핵심 설계 방향입니다. 이 PR은 WMMA layout에 이를 적용하여 새로운 아키텍처의 요구사항을 수용하면서도 기존 코드를 단순화했습니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, PR의 실제 diff를 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글