본문으로 건너뛰기

[Triton] MXFP4→BF16 변환에서 mul.bf16x2 강제 사용 — 1% MoE 성능 향상

PR 링크: triton-lang/triton#8967 상태: Merged | 변경: +17 / -1

들어가며

Triton에서 MXFP4(Microscaled FP4) 데이터를 BF16으로 변환할 때 스케일 값을 곱하는 단계가 있다. LLVM은 이 곱셈을 자동 벡터화할 때 스케일 브로드캐스팅 때문에 비효율적인 코드를 생성한다. 이 PR은 inline assembly로 mul.bf16x2를 직접 호출하여, ptxas가 HMUL2.BF16_V2 ... R.H0_H0 같은 최적 명령어를 생성하도록 유도한다.

핵심 코드 분석

Before

# Combine scale and x
x = x * scale

LLVM이 벡터화에 실패하여 scalar/vector mul이 섞인 비효율적 코드가 생성된다.

After

@triton.jit
def mul_bf16x2(a, b):
    use_mul: tl.constexpr = cuda_capability_geq(9)
    op_instr: tl.constexpr = "mul.bf16x2" if use_mul else "fma.rn.bf16x2"
    op_suffix: tl.constexpr = "" if use_mul else ", z"

    return tl.inline_asm_elementwise(
        asm=f"{op_instr} $0, $1, $2{op_suffix};",
        constraints="=r,r,r",
        args=[a, b],
        dtype=tl.bfloat16,
        is_pure=True,
        pack=2,
    )

# Combine scale and x
x = mul_bf16x2(x, scale)

pack=2로 2개의 bf16 값을 하나의 32비트 레지스터에 묶어 처리한다.

왜 이게 좋은가

  • ptxas 최적화 활용: mul.bf16x2 PTX 명령어에서 ptxas가 HMUL2.BF16_V2 R90, R90, R100.H0_H0처럼 스케일 브로드캐스팅을 레지스터 modifier로 융합한다.
  • 아키텍처 분기: SM90+(Hopper) 이상에서는 mul.bf16x2를, 그 이하에서는 fma.rn.bf16x2를 사용한다.
  • 실측 성능: non-persistent bf16xmxfp4 MoE에서 1%의 속도 향상이 측정되었다.

정리

컴파일러의 자동 벡터화가 실패하는 경우, inline assembly로 직접 원하는 명령어를 지정하는 것이 효과적인 우회 방법이다. 특히 ptxas가 후속 최적화를 수행할 수 있는 형태의 PTX를 생성하는 것이 핵심이다.

참고 자료


이 글은 AI 도구의 도움을 받아 작성되었습니다.

댓글

관련 포스트

PR Analysis 의 다른글