본문으로 건너뛰기

[triton] AMD AtomicCAS의 Tensor Operand Thread Predicate 수정

PR 링크: triton-lang/triton#9605 상태: Merged | 변경: +52 / -10

들어가며

Triton에서 tensor 기반 atomic 연산은 레이아웃에 따라 특정 스레드만이 실제로 atomic 명령을 실행해야 합니다. redundant thread(동일한 데이터를 중복 처리하는 스레드)가 atomic을 실행하면 데이터 corruption이 발생합니다. 이 PR은 AMD 백엔드의 AtomicCAS 변환에서 누락된 thread predicate와 register mask 처리를 추가합니다.

핵심 코드 분석

Before (thread predicate 없음):

for (size_t i = 0; i < elemsPerThread; i += 1) {
    // 모든 스레드가 무조건 atomic CAS 실행
    auto cmpxchg = LLVM::AtomicCmpXchgOp::create(...);
    resultVals[i] = ret;
}
// threadPred = b.true_val() 로 항상 true
finalizeTensorAtomicResults(op, ..., b.true_val(), ...);

After (thread predicate 및 register mask 적용):

auto freeVarMasks = getFreeVariableMasks(op.getPtr().getType());
Value threadPred = emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
uint32_t regMask = freeVarMasks[str_attr("reg")];

for (size_t i = 0; i < elemsPerThread; i += 1) {
    // register zero base로 인한 중복 건너뛰기
    if (tensorTy && (i & ~regMask) != i) {
        resultVals[i] = resultVals[i & ~regMask];
        continue;
    }
    // thread predicate로 브랜치
    LLVM::CondBrOp::create(rewriter, loc, threadPred, atomicBlock, endBlock, undefVal);
    // atomic CAS 실행
    auto cmpxchg = LLVM::AtomicCmpXchgOp::create(...);
    LLVM::BrOp::create(rewriter, loc, ret, endBlock);
    resultVals[i] = endBlock->getArgument(0);
}
finalizeTensorAtomicResults(op, ..., threadPred, ...);

왜 이게 좋은가

이 수정은 데이터 정합성 문제를 해결합니다. Free variable mask를 사용하여 broadcast된 레이아웃에서 동일한 메모리 위치에 여러 스레드가 atomic을 시도하는 것을 방지합니다. Register mask를 통해 zero base로 인한 중복 인덱스도 올바르게 처리합니다. endBlock 패턴으로 predicate된 스레드는 undef 값을, 실행된 스레드는 실제 CAS 결과를 받아 phi node 역할을 수행합니다.

정리

AMD AtomicCAS 변환에 redundant thread predicate와 register mask를 추가하여 broadcast 레이아웃에서의 중복 atomic 실행을 방지하고, 결과 취합 로직도 올바르게 수정했습니다.

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글