[Triton] WGMMA register pipelining에서 누락된 wait 삽입 수정
PR 링크: triton-lang/triton#8964 상태: Merged | 변경: +122 / -25
들어가며
NVIDIA Hopper 아키텍처의 WGMMA(Warp Group Matrix Multiply-Accumulate) 명령어는 비동기로 실행된다. 여러 WGMMA를 파이프라이닝할 때, accumulator 레지스터에 접근하기 전에 반드시 warp_group_dot_wait를 삽입해야 한다. 이 PR은 rsDotNeedsWait 로직이 일반적인 wait 삽입 경로를 우회하면서, persistent matmul의 epilogue처럼 accumulator를 실제로 읽는 경우에 wait가 누락되는 버그를 수정한다.
핵심 코드 분석
Before
if (rsDotNeedsWait(asyncDot, forOp)) {
// rs-dot용 wait만 삽입하고 continue
OpBuilder builder(asyncDot);
builder.setInsertionPointAfter(asyncDot);
auto newWait = ttng::WarpGroupDotWaitOp::create(...);
threadValuesThroughWait(newWait, waitOperands);
continue; // 일반 경로를 완전히 건너뜀
}
// 일반 wait 삽입 로직 (rs-dot는 여기 도달하지 못함)
After
// 1단계: 모든 async dot에 대해 일반 wait 삽입
for (auto asyncDot : properlyAsyncDots) {
// accumulator를 쓰는 wgmma는 하드웨어가 암묵적으로 파이프라이닝
auto firstUse = std::find_if_not(uses.begin(), uses.end(),
[](OpOperand *operand) {
return (isa<ttng::WarpGroupDotOp>(operand->getOwner()) &&
operand->getOperandNumber() == 2);
});
if (firstUse == uses.end()) continue;
// 첫 번째 비-accumulator 사용 앞에 wait 삽입
OpBuilder builder((*firstUse)->getOwner());
auto newWait = ttng::WarpGroupDotWaitOp::create(...);
}
// 2단계: rs-dot용 wait는 별도로 추가
for (auto asyncDot : properlyAsyncDots) {
if (!rsDotNeedsWait(asyncDot, forOp)) continue;
// rs-dot 전용 wait 삽입
}
왜 이게 좋은가
- 관심사 분리: rs-dot wait 로직과 일반 wait 로직이 독립적으로 동작하여, 한쪽이 다른 쪽을 차단하지 않는다.
- 하드웨어 파이프라이닝 활용: accumulator를 공유하는 WGMMA끼리는 하드웨어가 자동으로 파이프라이닝하므로 불필요한 wait를 삽입하지 않는다.
- 정확성: persistent matmul epilogue에서
scf.if내의 accumulator 접근에 올바르게wait {pendings = 0}이 삽입된다.
정리
이 PR은 두 가지 독립적인 동기화 요구사항(rs-dot용 wait와 일반 accumulator 접근용 wait)을 분리하여, 어떤 경우에도 필요한 wait가 누락되지 않도록 보장한다. 특히 persistent matmul처럼 epilogue에서 accumulator를 읽는 패턴에서 정확성 문제를 해결한다.
참고 자료
이 글은 AI 도구의 도움을 받아 작성되었습니다.
관련 포스트
- [Triton] SWP 루프 로우어링에서 barrier 위치 결정 로직 수정
- [triton] AMD Canonicalize Pointers에서 arith.select의 비대칭 fat pointer 처리 강화
- [Triton] FenceAsync에 비동기 읽기 의존성 추가 — st.shared와 copy_local_to_global 간 정합성 보장
- [Triton] Blackwell 2D activation-scale layout에서 ragged metadata 없이 동작하도록 수정
- [triton] NVIDIA TMA im2col 모드 Tensor Descriptor 지원
PR Analysis 의 다른글
- 이전글 [Triton] MXFP4→BF16 변환에서 mul.bf16x2 강제 사용 — 1% MoE 성능 향상
- 현재글 : [Triton] WGMMA register pipelining에서 누락된 wait 삽입 수정
- 다음글 [Triton] ConSan에 버퍼 aliasing 지원 추가 — 메모리 안전성 분석 강화
댓글