[triton] AMD: WMMA layout의 CTA 필드를 LinearLayout으로 일반화하여 swizzled warp 레이아웃 지원
PR 링크: triton-lang/triton#8946 상태: Merged | 변경: +901 / -767
들어가며
AMD GPU의 WMMA(Wave Matrix Multiply-Accumulate) layout은 warp가 WMMA 타일에 어떻게 매핑되는지를 정의합니다. 기존에는 warpsPerCTA와 tilesPerWarp 두 정수 배열로 표현했는데, 이로는 gfx1250에서 LDS 파티션 충돌을 줄이기 위한 swizzled warp 레이아웃을 표현할 수 없었습니다.
핵심 코드 분석
Before: warpsPerCTA + tilesPerWarp
// TritonGPUAttrDefs.td
let parameters = (ins
"unsigned": $version,
"bool":$isTransposed,
ArrayRefParameter<"unsigned">:$warpsPerCTA,
ArrayRefParameter<"unsigned">:$tilesPerWarp,
"CGAEncodingAttr":$CGALayout,
ArrayRefParameter<"unsigned">:$instrShape
);
이 방식은 warpsPerCTA = [2, 2], tilesPerWarp = [1, 1]로 다음 배치만 표현 가능:
w0 w1 w0 w1
w2 w3 w2 w3
After: LinearLayout 기반 ctaLayout
let parameters = (ins
"unsigned": $version,
LinearLayoutParam:$ctaLayout, // warpsPerCTA+tilesPerWarp 대체
"bool":$isTransposed,
"CGAEncodingAttr":$CGALayout,
ArrayRefParameter<"unsigned">:$instrShape
);
swizzled 레이아웃 표현 가능:
ctaLayout = {reg = [[2, 0]], warps = [[2, 1], [1, 0]]}
// 결과:
w0 w1 <- w1의 두 번째 타일
w2 w3
w0 w1 <- w1의 첫 번째 타일
w2 w3
최종 WMMA layout은 타일 내부 layout과 CTA layout의 합성으로 단순화됩니다:
wmmaLayout = tileLayout * ctaLayout
왜 이게 좋은가
- 표현력 확장: swizzled warp 레이아웃으로 LDS 파티션 충돌을 줄일 수 있습니다.
- lowering 단순화: tileLayout과 ctaLayout의 합성으로 WMMA/dotOperand lowering이 단순해졌습니다.
- 하위 호환: 기존 warpsPerCTA/tilesPerWarp 조합은 ctaLayout으로 자동 변환됩니다.
정리
레이아웃 표현을 LinearLayout으로 일반화하는 것은 Triton 컴파일러의 핵심 설계 방향입니다. 이 PR은 WMMA layout에 이를 적용하여 새로운 아키텍처의 요구사항을 수용하면서도 기존 코드를 단순화했습니다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었으며, PR의 실제 diff를 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [vllm] 비동기 스케줄링 기본 활성화로 GPU 유휴 시간 제거
- 현재글 : [triton] AMD: WMMA layout의 CTA 필드를 LinearLayout으로 일반화하여 swizzled warp 레이아웃 지원
- 다음글 [triton] AMD ReorderInstructions에서 효과 없는 sinkSecondLoad 최적화 제거
댓글