본문으로 건너뛰기

[triton] wgmma wait(0)를 accumulator 첫 사용 시점으로 지연하여 MMA-epilogue 오버랩 달성

PR 링크: triton-lang/triton#9021 상태: Merged | 변경: +45 / -28

들어가며

NVIDIA Hopper GPU의 wgmma(Warp Group Matrix Multiply Accumulate)는 비동기 행렬 곱셈 명령어입니다. 파이프라인된 루프 이후에는 wait(0)으로 모든 미완료 MMA가 완료될 때까지 대기해야 합니다. 기존에는 루프 직후에 wait를 삽입했지만, accumulator가 실제로 사용되기 전까지 다른 연산(epilogue)을 실행할 수 있는 기회를 놓치고 있었습니다.

핵심 코드 분석

Before:

// Wait until there are 0 outstanding async dot ops.
builder.setInsertionPointAfter(forOp);
auto WarpGroupDotWaitAfterLoop = ttng::WarpGroupDotWaitOp::create(
    builder, forOp.getLoc(), ArrayRef<Value>{}, 0);

After:

// Insert a wait(0) before the first use outside the loop
Operation *firstUse = nullptr;
for (auto accVal : waitOperands) {
  for (auto user : accVal.getUsers()) {
    auto target = curBlock->findAncestorOpInBlock(*user);
    if (!target) continue;
    if (!firstUse || target->isBeforeInBlock(firstUse))
      firstUse = target;
  }
}
if (firstUse) {
  builder.setInsertionPoint(firstUse);
} else {
  builder.setInsertionPoint(curBlock->getTerminator());
}
auto WarpGroupDotWaitAfterLoop = ttng::WarpGroupDotWaitOp::create(
    builder, forOp.getLoc(), ArrayRef<Value>{}, 0);

루프 결과(accumulator)의 모든 사용자를 스캔하여, 가장 먼저 사용하는 연산 직전에 wait를 삽입합니다. 루프와 첫 사용 사이에 있는 epilogue 연산(예: barrier invalidation, 메모리 해제)은 MMA와 병렬로 실행될 수 있습니다.

왜 이게 좋은가

이 최적화는 명령어 수준 병렬성(ILP)의 확대입니다. GPU는 비동기 MMA와 스칼라/메모리 연산을 동시에 실행할 수 있으므로, wait 시점을 늦추면 그만큼 오버랩 가능한 연산이 늘어납니다. non-persistent bf16 x mxfp4 MoE에서 0.4%의 안정적(repeatable) 속도 향상을 보였습니다.

정리

  • wait(0)을 루프 직후에서 accumulator 첫 사용 직전으로 이동
  • epilogue 연산(barrier inval, dealloc 등)이 MMA와 오버랩 가능
  • bf16 x mxfp4 MoE에서 0.4% 성능 향상

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글