본문으로 건너뛰기

[Triton] WGMMA register pipelining에서 누락된 wait 삽입 수정

PR 링크: triton-lang/triton#8964 상태: Merged | 변경: +122 / -25

들어가며

NVIDIA Hopper 아키텍처의 WGMMA(Warp Group Matrix Multiply-Accumulate) 명령어는 비동기로 실행된다. 여러 WGMMA를 파이프라이닝할 때, accumulator 레지스터에 접근하기 전에 반드시 warp_group_dot_wait를 삽입해야 한다. 이 PR은 rsDotNeedsWait 로직이 일반적인 wait 삽입 경로를 우회하면서, persistent matmul의 epilogue처럼 accumulator를 실제로 읽는 경우에 wait가 누락되는 버그를 수정한다.

핵심 코드 분석

Before

if (rsDotNeedsWait(asyncDot, forOp)) {
  // rs-dot용 wait만 삽입하고 continue
  OpBuilder builder(asyncDot);
  builder.setInsertionPointAfter(asyncDot);
  auto newWait = ttng::WarpGroupDotWaitOp::create(...);
  threadValuesThroughWait(newWait, waitOperands);
  continue;  // 일반 경로를 완전히 건너뜀
}
// 일반 wait 삽입 로직 (rs-dot는 여기 도달하지 못함)

After

// 1단계: 모든 async dot에 대해 일반 wait 삽입
for (auto asyncDot : properlyAsyncDots) {
  // accumulator를 쓰는 wgmma는 하드웨어가 암묵적으로 파이프라이닝
  auto firstUse = std::find_if_not(uses.begin(), uses.end(),
    [](OpOperand *operand) {
      return (isa<ttng::WarpGroupDotOp>(operand->getOwner()) &&
              operand->getOperandNumber() == 2);
    });
  if (firstUse == uses.end()) continue;
  // 첫 번째 비-accumulator 사용 앞에 wait 삽입
  OpBuilder builder((*firstUse)->getOwner());
  auto newWait = ttng::WarpGroupDotWaitOp::create(...);
}

// 2단계: rs-dot용 wait는 별도로 추가
for (auto asyncDot : properlyAsyncDots) {
  if (!rsDotNeedsWait(asyncDot, forOp)) continue;
  // rs-dot 전용 wait 삽입
}

왜 이게 좋은가

  • 관심사 분리: rs-dot wait 로직과 일반 wait 로직이 독립적으로 동작하여, 한쪽이 다른 쪽을 차단하지 않는다.
  • 하드웨어 파이프라이닝 활용: accumulator를 공유하는 WGMMA끼리는 하드웨어가 자동으로 파이프라이닝하므로 불필요한 wait를 삽입하지 않는다.
  • 정확성: persistent matmul epilogue에서 scf.if 내의 accumulator 접근에 올바르게 wait {pendings = 0}이 삽입된다.

정리

이 PR은 두 가지 독립적인 동기화 요구사항(rs-dot용 wait와 일반 accumulator 접근용 wait)을 분리하여, 어떤 경우에도 필요한 wait가 누락되지 않도록 보장한다. 특히 persistent matmul처럼 epilogue에서 accumulator를 읽는 패턴에서 정확성 문제를 해결한다.

참고 자료


이 글은 AI 도구의 도움을 받아 작성되었습니다.

댓글

관련 포스트

PR Analysis 의 다른글