[triton] SwiGLU 커널에 ex2.approx.ftz 적용으로 1-2 GBps 성능 개선
PR 링크: triton-lang/triton#9164 상태: Merged | 변경: +5 / -10
들어가며
SwiGLU(Swish-Gated Linear Unit)는 LLM에서 널리 사용되는 활성화 함수입니다. 내부적으로 exp 연산을 포함하는데, GPU에서 이를 최적화하면 전체 처리량에 미미하지만 반복적인 개선을 얻을 수 있습니다. 이 PR은 CUDA의 근사 지수 함수(ex2.approx.ftz)를 활용하여 약 1-2 GBps의 성능 향상을 달성합니다.
핵심 코드 분석
Before:
@triton.jit
def exp2_ftz(x):
if tl.target_info.is_cuda():
return tl.inline_asm_elementwise(
"ex2.approx.ftz.f32 $0, $1;",
"=r, r", args=[x], dtype=tl.float32, is_pure=True, pack=1,
)
else:
return tl.exp2(x)
def compute_swiglu(gelu, linear, scale, alpha, limit):
# ...
s = gelu / (1 + tl.exp(-alpha * gelu))
# TODO: exp(x) -> exp2(log2(e) * x) 최적화 가능
After:
@triton.jit
def exp_ftz(x):
if tl.target_info.is_cuda():
log2_e: tl.constexpr = 1.4426950408889634
x *= log2_e
return tl.inline_asm_elementwise(
"ex2.approx.ftz.f32 $0, $1;",
"=r, r", args=[x], dtype=tl.float32, is_pure=True, pack=1,
)
else:
return tl.exp(x)
def compute_swiglu(gelu, linear, scale, alpha, limit):
# ...
s = gelu / (1 + exp_ftz(-alpha * gelu))
기존 TODO 주석에 언급된 최적화를 실제로 구현했습니다. exp(x) = exp2(log2(e) * x) 변환을 적용하여, ex2.approx.ftz PTX 명령어를 사용합니다. .ftz(flush-to-zero)는 denormal 값을 0으로 처리하는데, SwiGLU에서 1 + exp(...) 형태이므로 denormal은 어차피 rounding으로 소실됩니다.
왜 이게 좋은가
이 최적화는 수치 안전성 분석이 뒷받침된 근사 최적화입니다. 단순히 근사 함수를 사용하는 것이 아니라, SwiGLU의 수학적 구조(1 + exp(...))에서 denormal이 의미 없다는 점을 논증한 후 적용했습니다. bf16 x mxfp4 MoE에서 약 0.1%의 개선이며, 절대적 수치는 작지만 반복 가능(repeatable)한 개선입니다. 이런 미세 최적화가 대규모 LLM 추론에서 누적되면 의미 있는 차이를 만듭니다.
정리
tl.exp을exp2(log2(e) * x)+ex2.approx.ftz.f32로 교체- 함수명을
exp2_ftz에서exp_ftz로 변경하여 의미 명확화 - denormal flush가 SwiGLU 수치에 영향 없음을 확인
- bf16 x mxfp4 MoE에서 약 1-2 GBps 개선
참고 자료
이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Ray Train] 벤치마크에 첫 번째 배치 시간 포함하여 정확한 처리량 측정
- 현재글 : [triton] SwiGLU 커널에 ex2.approx.ftz 적용으로 1-2 GBps 성능 개선
- 다음글 [vllm] gRPC Server Entrypoint - 고성능 gRPC 서빙 지원
댓글