[Triton] AMD에서 non-integer 타입 atomic-cas 시 컴파일러 크래시 수정
PR 링크: triton-lang/triton#9116 상태: Merged | 변경: +46 / -4
들어가며
AMD GPU에서 tt.atomic_cas를 float 타입에 대해 사용하면 컴파일러가 core dump를 발생시켰다. 이는 LLVM의 cmpxchg 명령어가 integer 타입만 지원하기 때문이다. 이 PR은 non-integer 타입의 atomic CAS 시 integer로 bitcast하여 cmpxchg를 수행하고, 결과를 다시 원래 타입으로 bitcast하는 처리를 추가한다.
핵심 코드 분석
Before: non-integer 타입을 그대로 cmpxchg에 전달
Value casVal = valElements[i];
Value casCmp = cmpElements[i];
Value casPtr = ptrElements[i];
// casVal이 f32이면 cmpxchg가 실패 -> core dump
auto cmpxchg = b.create_cmpxchg(casPtr, casCmp, casVal, ...);
Value ret = b.extract_val(valueElemTy, cmpxchg, 0);
After: integer bitcast 래핑
Value casVal = valElements[i];
Value casCmp = cmpElements[i];
Value casPtr = ptrElements[i];
Type valueElemIntTy{};
if (!valueElemTy.isSignlessInteger()) {
valueElemIntTy = rewriter.getIntegerType(valueElemNBits);
}
if (valueElemIntTy) {
casVal = LLVM::BitcastOp::create(rewriter, loc, valueElemIntTy, casVal);
casCmp = LLVM::BitcastOp::create(rewriter, loc, valueElemIntTy, casCmp);
}
auto cmpxchg = b.create_cmpxchg(casPtr, casCmp, casVal, ...);
Value ret;
if (valueElemIntTy) {
ret = b.extract_val(valueElemIntTy, cmpxchg, 0);
ret = LLVM::BitcastOp::create(rewriter, loc, valueElemTy, ret);
} else {
ret = b.extract_val(valueElemTy, cmpxchg, 0);
}
생성되는 LLVM IR
// f32 atomic CAS -> i32 bitcast로 수행
%c64i = llvm.bitcast %c64 : f32 to i32
%c32i = llvm.bitcast %c32 : f32 to i32
%cmpxchg = llvm.cmpxchg %ptr, %c32i, %c64i acquire monotonic
%resi = llvm.extractvalue %cmpxchg[0] : !llvm.struct<(i32, i1)>
%res = llvm.bitcast %resi : i32 to f32
왜 이게 좋은가
- 크래시 방지: float, bfloat16 등 non-integer 타입의 atomic CAS가 정상 동작한다.
- 표준 패턴: integer bitcast를 통한 atomic CAS는 LLVM/CUDA에서 널리 사용되는 패턴이다.
- tensor/scalar 모두 처리: tensor atomic CAS와 scalar atomic CAS 두 경로 모두 수정했다.
- CDNA2 flaky 테스트 분리: CDNA2에서 flaky한 테스트는 skip으로 처리하여 CI 안정성을 확보했다.
정리
이 PR은 AMD GPU에서 non-integer 타입(f32, f64, bf16 등)에 대한 atomic CAS 시 컴파일러가 core dump하는 문제를 수정했다. cmpxchg 전후로 integer bitcast를 삽입하여 LLVM의 integer-only 제약을 우회한다.
참고 자료
이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.
관련 포스트
- [triton] AMD Canonicalize Pointers에서 arith.select의 비대칭 fat pointer 처리 강화
- [triton] AMD FpSan dot 에뮬레이션의 MFMA/WMMA encoding 호환성 수정
- [triton] AMD BlockPingpong 패스의 non-MFMA dot 크래시 수정
- [Triton] HIPBackend에서 import torch 가드 추가 — JAX 호환성 복원
- [triton] AMD: PartitionedSharedEncodingAttr의 LLVM lowering 지원으로 공유 메모리 파티셔닝 구현
PR Analysis 의 다른글
- 이전글 [Triton] LLVM Debug Information에서 커널 인자 누락 수정
- 현재글 : [Triton] AMD에서 non-integer 타입 atomic-cas 시 컴파일러 크래시 수정
- 다음글 [pytest] actions/cache v4에서 v5로 업그레이드
댓글