[triton] AMD Batched WMMA Scaled에서 스케일 레이아웃 수정
PR 링크: triton-lang/triton#9545 상태: Merged | 변경: +98 / -10
들어가며
AMD gfx1250의 WMMA scaled 명령어는 블록 스케일링이 적용된 행렬 곱셈을 가속합니다. 이 PR은 batched(3D) 연산에서 스케일 텐서의 K/nonK 차원과 batch 차원이 올바르게 매핑되지 않던 버그를 수정합니다.
핵심 코드 분석
Before
SmallVector<int32_t> order;
if (rank == 3) {
order = {1, 0, 2}; // 하드코딩된 차원 순서
} else {
order = {1, 0};
}
auto dimK = outDimNames[order[0]];
auto dimNonK = outDimNames[order[1]];
After
bool hasBatchDim = rank == 3;
// rank에 관계없이 마지막 두 차원이 K와 nonK
auto dimK = outDimNames[rank - 1];
auto dimNonK = outDimNames[rank - 2];
if (hasBatchDim) {
tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[0]);
tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[0]);
}
if (dotOperandIdx == 1) {
// B operand의 스케일은 [N, K/32] 순서이므로 마지막 두 차원만 swap
SmallVector<int32_t> order = {1, 0};
if (hasBatchDim)
order = {0, 2, 1}; // batch 차원은 유지, K/N만 swap
ctaLayout = transposeLinearLayout(ctaLayout, order);
}
왜 이게 좋은가
- 정확한 차원 매핑: 2D와 3D에서 K/nonK 차원을 일관되게
rank-1/rank-2인덱스로 참조합니다. - Batch 차원 독립성: batch 차원(dim 0)은 transpose에서 제외하여 스케일 레이아웃이 올바르게 생성됩니다.
- 테스트 추가:
test_amd_wmma_scaled_batched테스트로 batched 경로의 정합성을 검증합니다.
정리
Batched matmul에서 스케일 레이아웃의 차원 순서 버그를 수정한 PR입니다. 하드코딩된 인덱스를 rank 기반의 상대적 인덱싱으로 바꾸어 2D/3D 모두 올바르게 동작하게 만들었습니다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Ray] Dashboard 죽은 노드 캐시의 변수 섀도잉 버그 수정
- 현재글 : [triton] AMD Batched WMMA Scaled에서 스케일 레이아웃 수정
- 다음글 [Grafana Loki] 쿼리 엔진 aggregator의 자료구조를 개선하여 38% 성능 향상
댓글