[triton] FpSan - Floating Point Sanitizer 도입
PR 링크: triton-lang/triton#9337 상태: Merged | 변경: +2091 / -1
들어가며
GPU 커널에서 NaN, Inf 같은 부동소수점 이상 값이 발생하면 디버깅이 매우 어렵습니다. 이 PR은 FpSan(Floating Point Sanitizer)이라는 새로운 instrumentation 패스를 Triton에 도입합니다. 컴파일 단계에서 FP 연산을 integer payload 방식으로 rewrite하여, 런타임에 FP 오류를 감지할 수 있게 합니다.
핵심 코드 분석
1. 새로운 MLIR 패스 등록
def TritonInstrumentFpSanitizer: Pass<"tritoninstrument-fp-sanitizer", "mlir::ModuleOp"> {
let summary = "Replace floating-point ops with integer-payload equivalents";
let description = "Rewrite selected floating-point operations to use integer
payload semantics for fpsan.";
let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::math::MathDialect",
"mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect",
"mlir::triton::instrument::TritonInstrumentDialect"];
}
2. Unary 연산 ID 열거형
enum class UnaryOpId : uint64_t {
Exp = 0, Log, Exp2, Log2, Cos, Sin,
Sqrt, Rsqrt, Erf, Floor, Ceil,
PreciseSqrt, DivInv,
};
각 부동소수점 단항 연산에 고유 ID를 부여하여, 런타임 리포트에서 어떤 연산이 문제를 일으켰는지 식별합니다.
3. TMEM Scratch 관리를 통한 상태 저장
class TmemScratchManager {
public:
static ttg::BlockedEncodingAttr
getOptimizedBlockedEncoding(PatternRewriter &rewriter,
ArrayRef<int64_t> shape, Type elemType) {
int numWarps = ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent());
int threadsPerWarp = ttg::lookupThreadsPerWarp(rewriter);
// ... vectorized 128-bit access를 위한 최적화된 layout 계산
unsigned maxElems = std::max(128u / elemBits, 1u);
// ...
}
};
FpSan은 Tensor Memory(TMEM)의 내용을 global scratch memory에 복사하여 검사합니다. TmemScratchManager는 이 scratch buffer를 할당하고 캐싱하여 중복 할당을 방지합니다.
왜 이게 좋은가
- 디버깅 효율 향상: NaN/Inf 발생 지점을 연산 레벨에서 정확히 추적할 수 있습니다.
- zero-overhead 원칙: 활성화하지 않으면 성능 영향이 없고, 활성화 시에도 최적화된 scratch 관리로 오버헤드를 최소화합니다.
- MLIR 패스 기반: 기존 컴파일 파이프라인에 자연스럽게 통합되며, 다른 sanitizer(ConSan 등)와 동일한 아키텍처를 따릅니다.
정리
FpSan은 Triton GPU 커널의 부동소수점 연산 오류를 런타임에 감지하는 새로운 도구입니다. MLIR 패스 시스템을 활용하여 FP 연산을 instrumented 버전으로 rewrite하고, global scratch memory를 통해 상태를 추적합니다. 1100줄 이상의 새 코드가 추가된 대규모 기능입니다.
참고 자료
이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
- [triton] GSan AxisInfo 기반 Shadow Update 중복 제거로 2~10배 성능 향상
- [Triton] FenceAsync에 비동기 읽기 의존성 추가 — st.shared와 copy_local_to_global 간 정합성 보장
- [triton] Generic Multi-CTA convert_layout 지원
- [triton] Gluon TMA Op Verifier 강화 및 Illegal Instruction Sanitize 모드 추가
- [triton] AMD Canonicalize Pointers에서 arith.select의 비대칭 fat pointer 처리 강화
PR Analysis 의 다른글
- 이전글 [Loki] memory.Bitmap 슬라이싱 지원: 비정렬 오프셋 처리
- 현재글 : [triton] FpSan - Floating Point Sanitizer 도입
- 다음글 [Ray] 메모리 모니터 리팩터링: cgroup 경로 주입으로 테스트 가능성 확보
댓글