본문으로 건너뛰기

[Triton] FPSan에서 exp/exp2의 대수적 성질을 보존하는 구현

들어가며

Triton의 FPSan(Floating-Point Sanitizer)은 커널 변환의 수학적 정확성을 검증하는 도구다. 핵심 아이디어는 float 연산을 integer 연산으로 매핑(homomorphism)하여 두 커널의 결과를 비트 단위로 비교하는 것이다. 이 PR은 expexp2에 대해 exp2(a + b) = exp2(a) * exp2(b) 항등식을 보존하는 FPSan 변환을 구현한다.

핵심 코드 분석

핵심 수학

exp2(x)를 integer 공간에서 C^x로 구현한다. 여기서 C = 0xa343836d는 5 mod 8인 상수로, uint32의 곱셈군에서 생성원(generator)이다.

Value fpsanExp2FromI32(PatternRewriter &rewriter, Location loc, Value xI, Type floatTy) {
    auto one = getIntConstantLike(rewriter, loc, xI.getType(), 1);
    auto zero = getIntConstantLike(rewriter, loc, xI.getType(), 0);
    auto c = getIntConstantLike(rewriter, loc, xI.getType(), 0xa343836d);

    Value y = one;
    for (int i = 0; i < 32; ++i) {
        y = arith::MulIOp::create(rewriter, loc, y, y);        // y = y^2
        auto bit = getIntConstantLike(rewriter, loc, xI.getType(), int64_t(1ull << (31 - i)));
        auto masked = arith::AndIOp::create(rewriter, loc, xI, bit);
        auto isZero = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, masked, zero);
        auto factor = arith::SelectOp::create(rewriter, loc, isZero, one, c);
        y = arith::MulIOp::create(rewriter, loc, y, factor);   // y *= (bit set ? C : 1)
    }
    return bitcastToFloat(rewriter, loc, y, floatTy);
}

이것이 작동하는 이유: 1 mod 4인 uint32는 곱셈 아래 순환군을 형성하고, 5 mod 8인 원소(C)는 이 순환군의 생성원이다. x -> C^x는 덧셈군에서 곱셈군으로의 준동형사상(homomorphism)이므로 C^(a+b) = C^a * C^b가 성립한다.

exp(x)exp2(x * rcp_log_2)로 구현한다:

Value fpsanExp(PatternRewriter &rewriter, Location loc, Value input) {
    auto inputI = bitcastToInt(rewriter, loc, input);
    auto rcpLog2 = getIntConstantLike(rewriter, loc, inputI.getType(), 0x3fb8aa3b);
    auto scaledI = arith::MulIOp::create(rewriter, loc, inputI, rcpLog2);
    return fpsanExp2FromI32(rewriter, loc, scaledI, input.getType());
}

왜 이게 좋은가

  • FlashAttention 검증: exp2(a + b) = exp2(a) * exp2(b) 항등식 보존으로, fused attention 커널과 naive attention 커널의 FPSan 결과가 블록 크기나 split-k 전략에 관계없이 일치한다.
  • Schanuel 추측: PR 저자는 Schanuel 추측이 참이면, ring 연산과 exp로 구성된 모든 수학적으로 유효한 변환이 FPSan에 의해 수용된다고 주장한다.
  • 실용성: 테스트에서 exp(x+y) vs exp(x)*exp(y)의 비트 단위 일치를 검증한다.

정리

+215/-8의 변경으로, 수론(number theory)의 순환군 이론을 GPU 커널 검증에 적용한 인상적인 PR이다. 수학적 아름다움과 실용성을 동시에 갖추고 있다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글