[triton] NVIDIA inval_barrier를 leader CTA에서만 실행하도록 변경
PR 링크: triton-lang/triton#9582 상태: Merged | 변경: +25 / -3
들어가며
NVIDIA의 mbarrier.inval 명령어는 mbarrier 객체를 무효화합니다. Multi-CTA 클러스터에서 barrier가 브로드캐스트된 경우(여러 CTA가 공유하는 barrier), inval은 barrier를 소유한 leader CTA에서만 실행해야 합니다.
핵심 코드 분석
Before:
Value pred = getElectWarp0OrThread0(*targetInfo, b);
// 모든 CTA에서 실행
ptxBuilder.create(ptx);
barSyncOp({ptxBuilder.newOperand(pred, "b"),
ptxBuilder.newOperand(smemObj.getBase(), "r")},
/*onlyAttachMLIRArgs=*/true);
After:
Value pred = getElectWarp0OrThread0(*targetInfo, b);
if (auto leaderPred =
LLVM::NVIDIA::getLeaderCTAPredicate(loc, rewriter, barrierTy))
pred = b.and_(pred, *leaderPred);
Value barrierPtr = LLVM::NVIDIA::getLeaderAddress(
loc, rewriter, smemObj.getBase(), barrierTy);
// leader CTA에서만 실행
barSyncOp({ptxBuilder.newOperand(pred, "b"),
ptxBuilder.newOperand(barrierPtr, "r")},
/*onlyAttachMLIRArgs=*/true);
getLeaderCTAPredicate: 현재 CTA가 barrier의 소유자(leader)인지 확인하는 predicate를 생성합니다.getLeaderAddress: Barrier의 물리적 주소를 leader CTA의 shared memory 주소로 변환합니다.- 두 predicate를
AND하여 leader CTA의 대표 스레드만mbarrier.inval을 실행합니다.
왜 이게 좋은가
- 정확성: Broadcasted barrier를 여러 CTA에서 동시에 invalidate하면 undefined behavior가 발생할 수 있습니다. Leader만 실행하여 이를 방지합니다.
- 간결한 구현: 기존
getLeaderCTAPredicate/getLeaderAddress유틸리티를 재활용합니다.
정리
Multi-CTA에서 broadcasted mbarrier의 invalidation을 leader CTA에서만 수행하도록 predicate와 주소를 올바르게 설정한 간결한 수정입니다.
참고 자료
이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] WSSpecialize에서 초기화된 Barrier의 Invalidation 추가
- 현재글 : [triton] NVIDIA inval_barrier를 leader CTA에서만 실행하도록 변경
- 다음글 [Ray] 파이프라인 최적 처리량 계산 유틸리티 함수 추가
댓글