본문으로 건너뛰기

[triton] AMD BlockPingpong 패스의 non-MFMA dot 크래시 수정

PR 링크: triton-lang/triton#9618 상태: Merged | 변경: +79 / -1

들어가며

AMD GPU의 BlockPingpong 최적화는 MFMA(Matrix Fused Multiply-Add) 명령어 기반의 dot 연산에서 스케줄링을 개선하는 패스입니다. 그러나 FMA 기반 dot 연산(MFMA가 아닌 일반 행렬 곱)이 이 패스에 전달되면, MFMA encoding으로의 잘못된 cast로 인해 크래시가 발생했습니다.

핵심 코드 분석

Before:

void Pingponger::getDotPingponged() {
  auto encoding = cast<RankedTensorType>(aType).getEncoding();
  auto srcEncoding = cast<ttg::DotOperandEncodingAttr>(encoding);
  kWidth = srcEncoding.getKWidth();
  auto mfmaEncoding = cast<ttg::AMDMfmaEncodingAttr>(srcEncoding.getParent());
  // cast 실패 시 크래시
  SmallVector<int64_t> intShape;
  auto mnkDim = mfmaEncoding.getInstrShape();
  // ...
}

After:

void Pingponger::getDotPingponged() {
  auto encoding = cast<RankedTensorType>(aType).getEncoding();
  auto srcEncoding = cast<ttg::DotOperandEncodingAttr>(encoding);
  kWidth = srcEncoding.getKWidth();
  auto mfmaEncoding =
      dyn_cast<ttg::AMDMfmaEncodingAttr>(srcEncoding.getParent());
  if (!mfmaEncoding) {
    LDBG("Encountered non-MFMA layout");
    return;
  }
  SmallVector<int64_t> intShape;
  auto mnkDim = mfmaEncoding.getInstrShape();
  // ...
}

castdyn_cast로 변경하고 null 체크를 추가하여, MFMA가 아닌 인코딩(예: FMA)일 때 안전하게 조기 반환합니다. 테스트에서는 #fake_mma = #ttg.blocked<...> 인코딩을 사용한 FMA dot으로 이 경로를 검증합니다.

왜 이게 좋은가

MLIR 기반 컴파일러에서 cast vs dyn_cast의 올바른 사용은 매우 중요합니다. cast는 타입이 확실할 때만 사용해야 하며, 그렇지 않으면 즉시 크래시합니다. 이 PR은 방어적 프로그래밍의 좋은 사례로, 현재 BlockPingpong이 MFMA 전용임을 명시하면서도 다른 인코딩에서 안전하게 동작하도록 보장합니다. 테스트 케이스에 CHECK-NOT: rocdl.sched.barrierCHECK-NOT: rocdl.s.setprio를 포함하여 최적화가 적용되지 않음을 검증합니다.

정리

  • cast<AMDMfmaEncodingAttr>dyn_cast로 변경하여 non-MFMA 크래시 방지
  • FMA dot 연산에 대한 테스트 케이스 추가
  • BlockPingpong이 MFMA 전용임을 코드에 명시적으로 표현

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글