본문으로 건너뛰기

[triton] NVIDIA inval_barrier를 leader CTA에서만 실행하도록 변경

PR 링크: triton-lang/triton#9582 상태: Merged | 변경: +25 / -3

들어가며

NVIDIA의 mbarrier.inval 명령어는 mbarrier 객체를 무효화합니다. Multi-CTA 클러스터에서 barrier가 브로드캐스트된 경우(여러 CTA가 공유하는 barrier), inval은 barrier를 소유한 leader CTA에서만 실행해야 합니다.

핵심 코드 분석

Before:

Value pred = getElectWarp0OrThread0(*targetInfo, b);
// 모든 CTA에서 실행
ptxBuilder.create(ptx);
barSyncOp({ptxBuilder.newOperand(pred, "b"),
           ptxBuilder.newOperand(smemObj.getBase(), "r")},
          /*onlyAttachMLIRArgs=*/true);

After:

Value pred = getElectWarp0OrThread0(*targetInfo, b);
if (auto leaderPred =
        LLVM::NVIDIA::getLeaderCTAPredicate(loc, rewriter, barrierTy))
  pred = b.and_(pred, *leaderPred);
Value barrierPtr = LLVM::NVIDIA::getLeaderAddress(
    loc, rewriter, smemObj.getBase(), barrierTy);
// leader CTA에서만 실행
barSyncOp({ptxBuilder.newOperand(pred, "b"),
           ptxBuilder.newOperand(barrierPtr, "r")},
          /*onlyAttachMLIRArgs=*/true);
  1. getLeaderCTAPredicate: 현재 CTA가 barrier의 소유자(leader)인지 확인하는 predicate를 생성합니다.
  2. getLeaderAddress: Barrier의 물리적 주소를 leader CTA의 shared memory 주소로 변환합니다.
  3. 두 predicate를 AND하여 leader CTA의 대표 스레드만 mbarrier.inval을 실행합니다.

왜 이게 좋은가

  • 정확성: Broadcasted barrier를 여러 CTA에서 동시에 invalidate하면 undefined behavior가 발생할 수 있습니다. Leader만 실행하여 이를 방지합니다.
  • 간결한 구현: 기존 getLeaderCTAPredicate/getLeaderAddress 유틸리티를 재활용합니다.

정리

Multi-CTA에서 broadcasted mbarrier의 invalidation을 leader CTA에서만 수행하도록 predicate와 주소를 올바르게 설정한 간결한 수정입니다.

참고 자료


이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글