본문으로 건너뛰기

[Triton] SwiGLU exp2 최적화 부분 롤백 — 수치 정확도 우선

PR 링크: triton-lang/triton#8905 상태: Merged | 변경: +5 / -2

들어가며

SwiGLU(Swish-Gated Linear Unit)는 LLM에서 널리 사용되는 활성화 함수다. 이전 PR #8801에서는 tl.exp(-alpha * gelu)exp2_ftz((-alpha * log2_e) * gelu)로 변환하여, PTX 수준에서 상수 폴딩을 유도하는 최적화를 도입했다. 그러나 일부 모델에서 수치 차이가 관측되어 이 PR에서 롤백한다.

핵심 코드 분석

Before (최적화된 버전)

# exp(x) → exp2(log2(e) * x), (-alpha * log2_e)를 상수로 폴딩
log2_e: tl.constexpr = 1.4426950408889634
s = gelu / (1 + exp2_ftz((-alpha * log2_e) * gelu))

After (롤백)

s = gelu / (1 + tl.exp(-alpha * gelu))

# TODO: Instead of using tl.exp(-alpha * gelu),
# there is potential way to reduce instructions:
# But we need to further understand its impact on model numerics.
# log2_e: tl.constexpr = 1.4426950408889634
# s = gelu / (1 + exp2_ftz((-alpha * log2_e) * gelu))

왜 이게 좋은가

  • 안전한 롤백: 수치 정확도는 ML 모델에서 결과 품질에 직접 영향을 미친다. 성능보다 정확성을 우선시하는 판단이다.
  • 최적화 코드 보존: 주석으로 최적화 코드를 남겨두어, E2E 평가 후 재적용할 수 있는 길을 열어둔다.
  • exp2_ftz vs exp: flush-to-zero 모드의 exp2와 표준 exp는 denormalized number 처리에서 차이가 있으며, 이것이 수치 차이의 원인일 수 있다.

정리

컴파일러 수준의 수학 함수 최적화는 항상 수치 정확도와의 트레이드오프를 수반한다. 이 PR은 "E2E 검증 없이 수학적 등가 변환을 적용하면 안 된다"는 교훈을 보여준다.

참고 자료


이 글은 AI 도구의 도움을 받아 작성되었습니다.

댓글

관련 포스트

PR Analysis 의 다른글