본문으로 건너뛰기

[triton] AMD FpSan dot 에뮬레이션의 MFMA/WMMA encoding 호환성 수정

PR 링크: triton-lang/triton#9655 상태: Merged | 변경: +233 / -33

들어가며

Triton의 FP Sanitizer(FpSan)는 부동소수점 연산의 정확성을 검증하기 위해 dot 연산을 작은 타일 단위로 에뮬레이션합니다. 그러나 이 에뮬레이션 타일에 원래의 MFMA/WMMA 인코딩을 적용하면 최소 shape 요구사항(예: 16x16)을 충족하지 못하는 문제가 있었습니다.

핵심 코드 분석

Before:

auto accLayout = cast<ttg::DistributedEncodingTrait>(cTy.getEncoding());
auto aLayout = cast<ttg::DistributedEncodingTrait>(aTy.getEncoding());
auto bLayout = cast<ttg::DistributedEncodingTrait>(bTy.getEncoding());

After:

// 에뮬레이션 타일에 최적화된 blocked layout 사용
auto accLayout = getOptimizedBlockedEncoding(rewriter, {tileM, tileN},
                                             cTy.getElementType());
auto aLayout =
    getOptimizedBlockedEncoding(rewriter, {tileM, k}, aTy.getElementType());
auto bLayout =
    getOptimizedBlockedEncoding(rewriter, {k, tileN}, bTy.getElementType());

// cross-warp barrier 추가
ttg::BarrierOp::create(rewriter, loc,
                       ttg::AddrSpace::GlobalRead | ttg::AddrSpace::GlobalWrite);

에뮬레이션 타일(kTileM=8, kTileN=8)은 AMDWmmaEncodingAttr의 최소 shape(16x16) 요구사항을 만족하지 못합니다. 이를 범용 BlockedEncodingAttr로 교체하여 임의 크기의 타일에서도 동작하도록 했습니다. 또한 각 warp가 scratch 메모리의 일부만 쓰기 때문에, 루프 전후에 barrier를 삽입하여 모든 warp의 쓰기가 완료된 후 읽기가 수행되도록 보장합니다.

왜 이게 좋은가

FpSan이 올바르게 동작하려면 모든 dot 인코딩에서 에뮬레이션이 정확해야 합니다. 이 수정은 하드웨어 특화 인코딩과 범용 에뮬레이션 로직 사이의 추상화 경계를 올바르게 설정합니다. 또한 cross-warp 동기화 누락이라는 미묘한 정확성 문제를 barrier 삽입으로 해결하여, FpSan의 신뢰성을 크게 높였습니다.

정리

  • 에뮬레이션 타일에 getOptimizedBlockedEncoding으로 범용 layout 적용
  • MFMA/WMMA의 최소 shape 요구사항 문제 해결
  • Cross-warp scratch 가시성을 위한 barrier 삽입
  • MFMA/WMMA encoding에 대한 전용 테스트 추가

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글