본문으로 건너뛰기

[triton] AMD Batched WMMA Scaled에서 스케일 레이아웃 수정

PR 링크: triton-lang/triton#9545 상태: Merged | 변경: +98 / -10

들어가며

AMD gfx1250의 WMMA scaled 명령어는 블록 스케일링이 적용된 행렬 곱셈을 가속합니다. 이 PR은 batched(3D) 연산에서 스케일 텐서의 K/nonK 차원과 batch 차원이 올바르게 매핑되지 않던 버그를 수정합니다.

핵심 코드 분석

Before

SmallVector<int32_t> order;
if (rank == 3) {
  order = {1, 0, 2};  // 하드코딩된 차원 순서
} else {
  order = {1, 0};
}
auto dimK = outDimNames[order[0]];
auto dimNonK = outDimNames[order[1]];

After

bool hasBatchDim = rank == 3;
// rank에 관계없이 마지막 두 차원이 K와 nonK
auto dimK = outDimNames[rank - 1];
auto dimNonK = outDimNames[rank - 2];

if (hasBatchDim) {
  tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[0]);
  tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[0]);
}

if (dotOperandIdx == 1) {
  // B operand의 스케일은 [N, K/32] 순서이므로 마지막 두 차원만 swap
  SmallVector<int32_t> order = {1, 0};
  if (hasBatchDim)
    order = {0, 2, 1};  // batch 차원은 유지, K/N만 swap
  ctaLayout = transposeLinearLayout(ctaLayout, order);
}

왜 이게 좋은가

  1. 정확한 차원 매핑: 2D와 3D에서 K/nonK 차원을 일관되게 rank-1/rank-2 인덱스로 참조합니다.
  2. Batch 차원 독립성: batch 차원(dim 0)은 transpose에서 제외하여 스케일 레이아웃이 올바르게 생성됩니다.
  3. 테스트 추가: test_amd_wmma_scaled_batched 테스트로 batched 경로의 정합성을 검증합니다.

정리

Batched matmul에서 스케일 레이아웃의 차원 순서 버그를 수정한 PR입니다. 하드코딩된 인덱스를 rank 기반의 상대적 인덱싱으로 바꾸어 2D/3D 모두 올바르게 동작하게 만들었습니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글