[Triton] FPSan에서 exp/exp2의 대수적 성질을 보존하는 구현
들어가며
Triton의 FPSan(Floating-Point Sanitizer)은 커널 변환의 수학적 정확성을 검증하는 도구다. 핵심 아이디어는 float 연산을 integer 연산으로 매핑(homomorphism)하여 두 커널의 결과를 비트 단위로 비교하는 것이다. 이 PR은 exp와 exp2에 대해 exp2(a + b) = exp2(a) * exp2(b) 항등식을 보존하는 FPSan 변환을 구현한다.
핵심 코드 분석
핵심 수학
exp2(x)를 integer 공간에서 C^x로 구현한다. 여기서 C = 0xa343836d는 5 mod 8인 상수로, uint32의 곱셈군에서 생성원(generator)이다.
Value fpsanExp2FromI32(PatternRewriter &rewriter, Location loc, Value xI, Type floatTy) {
auto one = getIntConstantLike(rewriter, loc, xI.getType(), 1);
auto zero = getIntConstantLike(rewriter, loc, xI.getType(), 0);
auto c = getIntConstantLike(rewriter, loc, xI.getType(), 0xa343836d);
Value y = one;
for (int i = 0; i < 32; ++i) {
y = arith::MulIOp::create(rewriter, loc, y, y); // y = y^2
auto bit = getIntConstantLike(rewriter, loc, xI.getType(), int64_t(1ull << (31 - i)));
auto masked = arith::AndIOp::create(rewriter, loc, xI, bit);
auto isZero = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, masked, zero);
auto factor = arith::SelectOp::create(rewriter, loc, isZero, one, c);
y = arith::MulIOp::create(rewriter, loc, y, factor); // y *= (bit set ? C : 1)
}
return bitcastToFloat(rewriter, loc, y, floatTy);
}
이것이 작동하는 이유: 1 mod 4인 uint32는 곱셈 아래 순환군을 형성하고, 5 mod 8인 원소(C)는 이 순환군의 생성원이다. x -> C^x는 덧셈군에서 곱셈군으로의 준동형사상(homomorphism)이므로 C^(a+b) = C^a * C^b가 성립한다.
exp(x)는 exp2(x * rcp_log_2)로 구현한다:
Value fpsanExp(PatternRewriter &rewriter, Location loc, Value input) {
auto inputI = bitcastToInt(rewriter, loc, input);
auto rcpLog2 = getIntConstantLike(rewriter, loc, inputI.getType(), 0x3fb8aa3b);
auto scaledI = arith::MulIOp::create(rewriter, loc, inputI, rcpLog2);
return fpsanExp2FromI32(rewriter, loc, scaledI, input.getType());
}
왜 이게 좋은가
- FlashAttention 검증:
exp2(a + b) = exp2(a) * exp2(b)항등식 보존으로, fused attention 커널과 naive attention 커널의 FPSan 결과가 블록 크기나 split-k 전략에 관계없이 일치한다. - Schanuel 추측: PR 저자는 Schanuel 추측이 참이면, ring 연산과 exp로 구성된 모든 수학적으로 유효한 변환이 FPSan에 의해 수용된다고 주장한다.
- 실용성: 테스트에서
exp(x+y)vsexp(x)*exp(y)의 비트 단위 일치를 검증한다.
정리
+215/-8의 변경으로, 수론(number theory)의 순환군 이론을 GPU 커널 검증에 적용한 인상적인 PR이다. 수학적 아름다움과 실용성을 동시에 갖추고 있다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글