본문으로 건너뛰기

[Triton] MXFP 포맷 출력 matmul 버그 2건 수정

PR 링크: triton-lang/triton#8865 상태: Merged | 변경: +9 / -2

들어가며

Triton의 고성능 matmul 커널은 MXFP(Microscaling Floating Point) 포맷으로 결과를 직접 출력하는 epilogue를 지원한다. 이 PR은 MXFP downcast epilogue에서 발생하는 두 가지 버그를 수정한다: (1) scale 마스크 계산에서 잘못된 전역 차원 N 사용, (2) block_m=64에서 shared memory overflow.

핵심 코드 분석

Bug 1: Scale mask 계산 오류

Before:

N_MX_BLOCK = tl.cdiv(N, MXFP_BLOCK_SIZE)
# ...
mask_n_scale = offs_y_n_scale < N_MX_BLOCK

전역 차원 N으로 scale 블록 수를 계산했다. 하지만 epilogue에서 사용하는 차원은 로컬 yN이다.

After:

mask_n_scale = offs_y_n_scale < tl.cdiv(yN, MXFP_BLOCK_SIZE)

로컬 차원 yN을 사용하여 정확한 마스크를 생성한다.

Bug 2: Shared memory overflow

After (opt_flags.py):

if block_m == 64 and precision_config.c_mx_scale is not None \
    and rhs_dtype == FP4 and torch.cuda.get_device_capability()[0] >= 10:
    block_m = 128

fused activation과 MXFP8 downcast가 동시에 적용될 때 block_m=64이면 shared memory가 초과하는 문제를 block_m=128로 변경하여 해결했다.

왜 이게 좋은가

  1. 정확성: N vs yN 차이로 인한 잘못된 마스킹이 out-of-bounds 접근을 유발할 수 있었다
  2. 안정성: shared memory overflow는 커널 실행 실패로 직결되는 심각한 버그
  3. 테스트 추가: swiglu + mxfp8 downcast 조합에 대한 테스트 케이스 추가

정리

GPU 커널에서 전역 차원과 로컬 차원(타일 단위)을 혼동하는 것은 흔한 실수지만 찾기 어려운 버그를 만든다. MXFP 같은 양자화 포맷은 scale 블록 단위의 계산이 추가되어 이런 차원 관리가 더 복잡해진다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글

관련 포스트

PR Analysis 의 다른글