[Triton] RDNA에서 bf16 곱셈에 V_DOT2_BF16_BF16 명령어 활용
들어가며
AMD RDNA 아키텍처(gfx11/gfx12)는 bf16 곱셈을 하드웨어에서 직접 지원하지 않는다. 기존에는 fp32로 upcast하고 곱셈 후 다시 bf16으로 downcast하는데, 이 downcast 과정이 비용이 높다. 이 PR은 V_DOT2_BF16_BF16 명령어를 활용하여 downcast 없이 bf16 곱셈 결과를 얻는 트릭을 적용한다.
핵심 코드 분석
Before
SmallVector<Value> createDestOps(...) {
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
return {EmitDualBF16ElementwiseOp<LLVM::FMulOp>(loc, rewriter, operands)};
}
}
bf16 값을 fp32로 변환한 뒤 fp32 곱셈을 하고, 다시 bf16으로 downcast했다. downcast에서 round-to-nearest-even이 소프트웨어로 구현되어 느렸다.
After
if (isRDNA(isaFamily)) {
// V_DOT2_BF16_BF16: res = a*b + 0*0 + 0
auto b = TritonLLVMOpBuilder(loc, rewriter);
Value aVal = packLLVector(loc, ValueRange{operands[0][0], b.bf16_val(0.0)}, rewriter);
Value bVal = packLLVector(loc, ValueRange{operands[0][1], b.bf16_val(0.0)}, rewriter);
return {LLVM::createLLVMIntrinsicCallOp(
rewriter, loc, "llvm.amdgcn.fdot2.bf16.bf16", bf16_ty,
ValueRange{aVal, bVal, b.bf16_val(0.0)})
->getResult(0)};
}
V_DOT2_BF16_BF16(a, b, c) = a[0]*b[0] + a[1]*b[1] + c를 활용한다. a = [x, 0], b = [y, 0], c = 0으로 설정하면 결과가 x*y + 0*0 + 0 = x*y가 된다. 하드웨어가 round-to-nearest-even을 수행하므로 기존 구현과 bit-for-bit 동일한 결과를 보장하면서 더 빠르다.
왜 이게 좋은가
- 성능 향상: 소프트웨어 downcast를 하드웨어 dot product로 대체하여 속도가 개선된다.
- 정확성 보장: bit-for-bit 동일한 결과를 낸다고 PR description에서 명시하고 있다.
- 최소 변경: +38/-3으로 핵심 로직만 추가했다.
정리
하드웨어 명령어의 창의적 활용으로 성능을 개선한 사례다. dot product를 "곱셈기"로 재해석하는 트릭이 인상적이다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글