본문으로 건너뛰기

[triton] SwiGLU 커널에 ex2.approx.ftz 적용으로 1-2 GBps 성능 개선

PR 링크: triton-lang/triton#9164 상태: Merged | 변경: +5 / -10

들어가며

SwiGLU(Swish-Gated Linear Unit)는 LLM에서 널리 사용되는 활성화 함수입니다. 내부적으로 exp 연산을 포함하는데, GPU에서 이를 최적화하면 전체 처리량에 미미하지만 반복적인 개선을 얻을 수 있습니다. 이 PR은 CUDA의 근사 지수 함수(ex2.approx.ftz)를 활용하여 약 1-2 GBps의 성능 향상을 달성합니다.

핵심 코드 분석

Before:

@triton.jit
def exp2_ftz(x):
    if tl.target_info.is_cuda():
        return tl.inline_asm_elementwise(
            "ex2.approx.ftz.f32 $0, $1;",
            "=r, r", args=[x], dtype=tl.float32, is_pure=True, pack=1,
        )
    else:
        return tl.exp2(x)

def compute_swiglu(gelu, linear, scale, alpha, limit):
    # ...
    s = gelu / (1 + tl.exp(-alpha * gelu))
    # TODO: exp(x) -> exp2(log2(e) * x) 최적화 가능

After:

@triton.jit
def exp_ftz(x):
    if tl.target_info.is_cuda():
        log2_e: tl.constexpr = 1.4426950408889634
        x *= log2_e
        return tl.inline_asm_elementwise(
            "ex2.approx.ftz.f32 $0, $1;",
            "=r, r", args=[x], dtype=tl.float32, is_pure=True, pack=1,
        )
    else:
        return tl.exp(x)

def compute_swiglu(gelu, linear, scale, alpha, limit):
    # ...
    s = gelu / (1 + exp_ftz(-alpha * gelu))

기존 TODO 주석에 언급된 최적화를 실제로 구현했습니다. exp(x) = exp2(log2(e) * x) 변환을 적용하여, ex2.approx.ftz PTX 명령어를 사용합니다. .ftz(flush-to-zero)는 denormal 값을 0으로 처리하는데, SwiGLU에서 1 + exp(...) 형태이므로 denormal은 어차피 rounding으로 소실됩니다.

왜 이게 좋은가

이 최적화는 수치 안전성 분석이 뒷받침된 근사 최적화입니다. 단순히 근사 함수를 사용하는 것이 아니라, SwiGLU의 수학적 구조(1 + exp(...))에서 denormal이 의미 없다는 점을 논증한 후 적용했습니다. bf16 x mxfp4 MoE에서 약 0.1%의 개선이며, 절대적 수치는 작지만 반복 가능(repeatable)한 개선입니다. 이런 미세 최적화가 대규모 LLM 추론에서 누적되면 의미 있는 차이를 만듭니다.

정리

  • tl.expexp2(log2(e) * x) + ex2.approx.ftz.f32로 교체
  • 함수명을 exp2_ftz에서 exp_ftz로 변경하여 의미 명확화
  • denormal flush가 SwiGLU 수치에 영향 없음을 확인
  • bf16 x mxfp4 MoE에서 약 1-2 GBps 개선

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글