본문으로 건너뛰기

[Triton] AMD에서 non-integer 타입 atomic-cas 시 컴파일러 크래시 수정

PR 링크: triton-lang/triton#9116 상태: Merged | 변경: +46 / -4

들어가며

AMD GPU에서 tt.atomic_cas를 float 타입에 대해 사용하면 컴파일러가 core dump를 발생시켰다. 이는 LLVM의 cmpxchg 명령어가 integer 타입만 지원하기 때문이다. 이 PR은 non-integer 타입의 atomic CAS 시 integer로 bitcast하여 cmpxchg를 수행하고, 결과를 다시 원래 타입으로 bitcast하는 처리를 추가한다.

핵심 코드 분석

Before: non-integer 타입을 그대로 cmpxchg에 전달

Value casVal = valElements[i];
Value casCmp = cmpElements[i];
Value casPtr = ptrElements[i];
// casVal이 f32이면 cmpxchg가 실패 -> core dump
auto cmpxchg = b.create_cmpxchg(casPtr, casCmp, casVal, ...);
Value ret = b.extract_val(valueElemTy, cmpxchg, 0);

After: integer bitcast 래핑

Value casVal = valElements[i];
Value casCmp = cmpElements[i];
Value casPtr = ptrElements[i];

Type valueElemIntTy{};
if (!valueElemTy.isSignlessInteger()) {
  valueElemIntTy = rewriter.getIntegerType(valueElemNBits);
}

if (valueElemIntTy) {
  casVal = LLVM::BitcastOp::create(rewriter, loc, valueElemIntTy, casVal);
  casCmp = LLVM::BitcastOp::create(rewriter, loc, valueElemIntTy, casCmp);
}

auto cmpxchg = b.create_cmpxchg(casPtr, casCmp, casVal, ...);

Value ret;
if (valueElemIntTy) {
  ret = b.extract_val(valueElemIntTy, cmpxchg, 0);
  ret = LLVM::BitcastOp::create(rewriter, loc, valueElemTy, ret);
} else {
  ret = b.extract_val(valueElemTy, cmpxchg, 0);
}

생성되는 LLVM IR

// f32 atomic CAS -> i32 bitcast로 수행
%c64i = llvm.bitcast %c64 : f32 to i32
%c32i = llvm.bitcast %c32 : f32 to i32
%cmpxchg = llvm.cmpxchg %ptr, %c32i, %c64i acquire monotonic
%resi = llvm.extractvalue %cmpxchg[0] : !llvm.struct<(i32, i1)>
%res = llvm.bitcast %resi : i32 to f32

왜 이게 좋은가

  1. 크래시 방지: float, bfloat16 등 non-integer 타입의 atomic CAS가 정상 동작한다.
  2. 표준 패턴: integer bitcast를 통한 atomic CAS는 LLVM/CUDA에서 널리 사용되는 패턴이다.
  3. tensor/scalar 모두 처리: tensor atomic CAS와 scalar atomic CAS 두 경로 모두 수정했다.
  4. CDNA2 flaky 테스트 분리: CDNA2에서 flaky한 테스트는 skip으로 처리하여 CI 안정성을 확보했다.

정리

이 PR은 AMD GPU에서 non-integer 타입(f32, f64, bf16 등)에 대한 atomic CAS 시 컴파일러가 core dump하는 문제를 수정했다. cmpxchg 전후로 integer bitcast를 삽입하여 LLVM의 integer-only 제약을 우회한다.

참고 자료


이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.

댓글

관련 포스트

PR Analysis 의 다른글