[triton] AMD AtomicCAS의 Tensor Operand Thread Predicate 수정
PR 링크: triton-lang/triton#9605 상태: Merged | 변경: +52 / -10
들어가며
Triton에서 tensor 기반 atomic 연산은 레이아웃에 따라 특정 스레드만이 실제로 atomic 명령을 실행해야 합니다. redundant thread(동일한 데이터를 중복 처리하는 스레드)가 atomic을 실행하면 데이터 corruption이 발생합니다. 이 PR은 AMD 백엔드의 AtomicCAS 변환에서 누락된 thread predicate와 register mask 처리를 추가합니다.
핵심 코드 분석
Before (thread predicate 없음):
for (size_t i = 0; i < elemsPerThread; i += 1) {
// 모든 스레드가 무조건 atomic CAS 실행
auto cmpxchg = LLVM::AtomicCmpXchgOp::create(...);
resultVals[i] = ret;
}
// threadPred = b.true_val() 로 항상 true
finalizeTensorAtomicResults(op, ..., b.true_val(), ...);
After (thread predicate 및 register mask 적용):
auto freeVarMasks = getFreeVariableMasks(op.getPtr().getType());
Value threadPred = emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
uint32_t regMask = freeVarMasks[str_attr("reg")];
for (size_t i = 0; i < elemsPerThread; i += 1) {
// register zero base로 인한 중복 건너뛰기
if (tensorTy && (i & ~regMask) != i) {
resultVals[i] = resultVals[i & ~regMask];
continue;
}
// thread predicate로 브랜치
LLVM::CondBrOp::create(rewriter, loc, threadPred, atomicBlock, endBlock, undefVal);
// atomic CAS 실행
auto cmpxchg = LLVM::AtomicCmpXchgOp::create(...);
LLVM::BrOp::create(rewriter, loc, ret, endBlock);
resultVals[i] = endBlock->getArgument(0);
}
finalizeTensorAtomicResults(op, ..., threadPred, ...);
왜 이게 좋은가
이 수정은 데이터 정합성 문제를 해결합니다. Free variable mask를 사용하여 broadcast된 레이아웃에서 동일한 메모리 위치에 여러 스레드가 atomic을 시도하는 것을 방지합니다. Register mask를 통해 zero base로 인한 중복 인덱스도 올바르게 처리합니다. endBlock 패턴으로 predicate된 스레드는 undef 값을, 실행된 스레드는 실제 CAS 결과를 받아 phi node 역할을 수행합니다.
정리
AMD AtomicCAS 변환에 redundant thread predicate와 register mask를 추가하여 broadcast 레이아웃에서의 중복 atomic 실행을 방지하고, 결과 취합 로직도 올바르게 수정했습니다.
참고 자료
이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] AMD Pipelined Loop에서 TDM Load의 Buffer Race 수정
- 현재글 : [triton] AMD AtomicCAS의 Tensor Operand Thread Predicate 수정
- 다음글 [Uvicorn] bytes에서 bytearray로 변경하여 HTTP 바디 누적 O(n²) → O(n) 개선
댓글