본문으로 건너뛰기

[triton] FpSan - Floating Point Sanitizer 도입

PR 링크: triton-lang/triton#9337 상태: Merged | 변경: +2091 / -1

들어가며

GPU 커널에서 NaN, Inf 같은 부동소수점 이상 값이 발생하면 디버깅이 매우 어렵습니다. 이 PR은 FpSan(Floating Point Sanitizer)이라는 새로운 instrumentation 패스를 Triton에 도입합니다. 컴파일 단계에서 FP 연산을 integer payload 방식으로 rewrite하여, 런타임에 FP 오류를 감지할 수 있게 합니다.

핵심 코드 분석

1. 새로운 MLIR 패스 등록

def TritonInstrumentFpSanitizer: Pass<"tritoninstrument-fp-sanitizer", "mlir::ModuleOp"> {
  let summary = "Replace floating-point ops with integer-payload equivalents";
  let description = "Rewrite selected floating-point operations to use integer
    payload semantics for fpsan.";
  let dependentDialects = ["mlir::arith::ArithDialect",
                           "mlir::math::MathDialect",
                           "mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::TritonDialect",
                           "mlir::triton::instrument::TritonInstrumentDialect"];
}

2. Unary 연산 ID 열거형

enum class UnaryOpId : uint64_t {
  Exp = 0, Log, Exp2, Log2, Cos, Sin,
  Sqrt, Rsqrt, Erf, Floor, Ceil,
  PreciseSqrt, DivInv,
};

각 부동소수점 단항 연산에 고유 ID를 부여하여, 런타임 리포트에서 어떤 연산이 문제를 일으켰는지 식별합니다.

3. TMEM Scratch 관리를 통한 상태 저장

class TmemScratchManager {
public:
  static ttg::BlockedEncodingAttr
  getOptimizedBlockedEncoding(PatternRewriter &rewriter,
                              ArrayRef<int64_t> shape, Type elemType) {
    int numWarps = ttg::lookupNumWarps(rewriter.getInsertionBlock()->getParent());
    int threadsPerWarp = ttg::lookupThreadsPerWarp(rewriter);
    // ... vectorized 128-bit access를 위한 최적화된 layout 계산
    unsigned maxElems = std::max(128u / elemBits, 1u);
    // ...
  }
};

FpSan은 Tensor Memory(TMEM)의 내용을 global scratch memory에 복사하여 검사합니다. TmemScratchManager는 이 scratch buffer를 할당하고 캐싱하여 중복 할당을 방지합니다.

왜 이게 좋은가

  • 디버깅 효율 향상: NaN/Inf 발생 지점을 연산 레벨에서 정확히 추적할 수 있습니다.
  • zero-overhead 원칙: 활성화하지 않으면 성능 영향이 없고, 활성화 시에도 최적화된 scratch 관리로 오버헤드를 최소화합니다.
  • MLIR 패스 기반: 기존 컴파일 파이프라인에 자연스럽게 통합되며, 다른 sanitizer(ConSan 등)와 동일한 아키텍처를 따릅니다.

정리

FpSan은 Triton GPU 커널의 부동소수점 연산 오류를 런타임에 감지하는 새로운 도구입니다. MLIR 패스 시스템을 활용하여 FP 연산을 instrumented 버전으로 rewrite하고, global scratch memory를 통해 상태를 추적합니다. 1100줄 이상의 새 코드가 추가된 대규모 기능입니다.

참고 자료


이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글