본문으로 건너뛰기

[Triton] clamp 최적화를 scalar에도 적용 — fmin.xorsign.abs 활용

PR 링크: triton-lang/triton#8796 상태: Merged | 변경: +43 / -23

들어가며

Triton은 tt.clampf(x, -limit, limit) 패턴을 감지하여 NVIDIA Hopper의 fmin.xorsign.abs PTX 명령어로 변환하는 최적화를 가지고 있다. 이 명령어는 clamp를 단일 명령어로 수행하여 3개의 비교/선택 명령어를 대체한다. 그러나 기존 구현은 텐서에만 동작했고, scalar 값에는 적용되지 않았다.

핵심 코드 분석

Before

auto getSplatInitializer = [](Value v) -> std::optional<double> {
  if (auto constOp = v.getDefiningOp<arith::ConstantOp>()) {
    if (auto attr = mlir::dyn_cast<DenseIntOrFPElementsAttr>(
            constOp.getValueAttr())) {
      if (attr.isSplat()) {
        return attr.getSplatValue<APFloat>().convertToDouble();
      }
    }
  }
  return std::nullopt;
};

DenseIntOrFPElementsAttr만 처리하므로 scalar FloatAttr는 매칭되지 않는다.

After

auto getSplatInitializer = [](Value v) -> std::optional<double> {
  DenseIntOrFPElementsAttr denseAttr;
  if (matchPattern(v, m_Constant(&denseAttr))) {
    if (denseAttr.isSplat()) {
      return denseAttr.getSplatValue<APFloat>().convertToDouble();
    }
    return std::nullopt;
  }
  FloatAttr floatAttr;
  if (matchPattern(v, m_Constant(&floatAttr))) {
    return floatAttr.getValue().convertToDouble();
  }
  return std::nullopt;
};

matchPatternFloatAttr 분기를 추가하여 scalar 상수도 매칭한다.

왜 이게 좋은가

  • 일관성: 텐서와 scalar 모두 동일한 최적화가 적용되어, 코드 형태에 따른 성능 차이가 사라진다.
  • 코드 정리: matchPattern 사용으로 패턴 매칭 코드가 더 간결하고 MLIR 관용적이 되었다.
  • 테스트 추가: scalar clamp에 대한 nvvm.fmin.xorsign.abs.f 생성을 검증하는 lit test가 포함되었다.

정리

단순한 타입 분기 추가지만, "텐서에만 동작하는 최적화"라는 불필요한 제약을 제거하여 최적화의 적용 범위를 넓혔다.

참고 자료


이 글은 AI 도구의 도움을 받아 작성되었습니다.

댓글

관련 포스트

PR Analysis 의 다른글