[Triton] MXFP4→BF16 변환에서 mul.bf16x2 강제 사용 — 1% MoE 성능 향상
PR 링크: triton-lang/triton#8967 상태: Merged | 변경: +17 / -1
들어가며
Triton에서 MXFP4(Microscaled FP4) 데이터를 BF16으로 변환할 때 스케일 값을 곱하는 단계가 있다. LLVM은 이 곱셈을 자동 벡터화할 때 스케일 브로드캐스팅 때문에 비효율적인 코드를 생성한다. 이 PR은 inline assembly로 mul.bf16x2를 직접 호출하여, ptxas가 HMUL2.BF16_V2 ... R.H0_H0 같은 최적 명령어를 생성하도록 유도한다.
핵심 코드 분석
Before
# Combine scale and x
x = x * scale
LLVM이 벡터화에 실패하여 scalar/vector mul이 섞인 비효율적 코드가 생성된다.
After
@triton.jit
def mul_bf16x2(a, b):
use_mul: tl.constexpr = cuda_capability_geq(9)
op_instr: tl.constexpr = "mul.bf16x2" if use_mul else "fma.rn.bf16x2"
op_suffix: tl.constexpr = "" if use_mul else ", z"
return tl.inline_asm_elementwise(
asm=f"{op_instr} $0, $1, $2{op_suffix};",
constraints="=r,r,r",
args=[a, b],
dtype=tl.bfloat16,
is_pure=True,
pack=2,
)
# Combine scale and x
x = mul_bf16x2(x, scale)
pack=2로 2개의 bf16 값을 하나의 32비트 레지스터에 묶어 처리한다.
왜 이게 좋은가
- ptxas 최적화 활용:
mul.bf16x2PTX 명령어에서 ptxas가HMUL2.BF16_V2 R90, R90, R100.H0_H0처럼 스케일 브로드캐스팅을 레지스터 modifier로 융합한다. - 아키텍처 분기: SM90+(Hopper) 이상에서는
mul.bf16x2를, 그 이하에서는fma.rn.bf16x2를 사용한다. - 실측 성능: non-persistent bf16xmxfp4 MoE에서 1%의 속도 향상이 측정되었다.
정리
컴파일러의 자동 벡터화가 실패하는 경우, inline assembly로 직접 원하는 명령어를 지정하는 것이 효과적인 우회 방법이다. 특히 ptxas가 후속 최적화를 수행할 수 있는 형태의 PTX를 생성하는 것이 핵심이다.
참고 자료
이 글은 AI 도구의 도움을 받아 작성되었습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Ray Data] LLM 배치 추론에서 개별 행 실패 시에도 작업을 계속하는 에러 핸들링 추가
- 현재글 : [Triton] MXFP4→BF16 변환에서 mul.bf16x2 강제 사용 — 1% MoE 성능 향상
- 다음글 [Triton] WGMMA register pipelining에서 누락된 wait 삽입 수정
댓글