본문으로 건너뛰기

[Triton] RDNA에서 bf16 곱셈에 V_DOT2_BF16_BF16 명령어 활용

들어가며

AMD RDNA 아키텍처(gfx11/gfx12)는 bf16 곱셈을 하드웨어에서 직접 지원하지 않는다. 기존에는 fp32로 upcast하고 곱셈 후 다시 bf16으로 downcast하는데, 이 downcast 과정이 비용이 높다. 이 PR은 V_DOT2_BF16_BF16 명령어를 활용하여 downcast 없이 bf16 곱셈 결과를 얻는 트릭을 적용한다.

핵심 코드 분석

Before

SmallVector<Value> createDestOps(...) {
    if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
        return {EmitDualBF16ElementwiseOp<LLVM::FMulOp>(loc, rewriter, operands)};
    }
}

bf16 값을 fp32로 변환한 뒤 fp32 곱셈을 하고, 다시 bf16으로 downcast했다. downcast에서 round-to-nearest-even이 소프트웨어로 구현되어 느렸다.

After

if (isRDNA(isaFamily)) {
    // V_DOT2_BF16_BF16: res = a*b + 0*0 + 0
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    Value aVal = packLLVector(loc, ValueRange{operands[0][0], b.bf16_val(0.0)}, rewriter);
    Value bVal = packLLVector(loc, ValueRange{operands[0][1], b.bf16_val(0.0)}, rewriter);
    return {LLVM::createLLVMIntrinsicCallOp(
                rewriter, loc, "llvm.amdgcn.fdot2.bf16.bf16", bf16_ty,
                ValueRange{aVal, bVal, b.bf16_val(0.0)})
                ->getResult(0)};
}

V_DOT2_BF16_BF16(a, b, c) = a[0]*b[0] + a[1]*b[1] + c를 활용한다. a = [x, 0], b = [y, 0], c = 0으로 설정하면 결과가 x*y + 0*0 + 0 = x*y가 된다. 하드웨어가 round-to-nearest-even을 수행하므로 기존 구현과 bit-for-bit 동일한 결과를 보장하면서 더 빠르다.

왜 이게 좋은가

  • 성능 향상: 소프트웨어 downcast를 하드웨어 dot product로 대체하여 속도가 개선된다.
  • 정확성 보장: bit-for-bit 동일한 결과를 낸다고 PR description에서 명시하고 있다.
  • 최소 변경: +38/-3으로 핵심 로직만 추가했다.

정리

하드웨어 명령어의 창의적 활용으로 성능을 개선한 사례다. dot product를 "곱셈기"로 재해석하는 트릭이 인상적이다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글