[Triton] matmul 커널에 nvfp4 x nvfp4, mxfp4 x mxfp4 지원 추가
들어가며
4비트 부동소수점(FP4)은 LLM 추론에서 메모리 대역폭과 계산 비용을 크게 줄이는 핵심 기술이다. 이 PR은 Triton의 matmul 커널에서 양쪽 피연산자가 모두 FP4인 경우(nvfp4 x nvfp4, mxfp4 x mxfp4)를 지원하도록 확장하고, quantization 유틸리티에 NVFP4 지원을 추가한다.
핵심 코드 분석
DType 클래스 확장
# Before
self.has_mx_scale = dtype_str.startswith("mx")
self.torch_dtype = to_torch_dtype(dtype_str.strip("mx"))
# After
self.is_nvfp4 = dtype_str == "nvfp4_e2m1"
self.has_mx_scale = dtype_str.startswith("mx") or self.is_nvfp4
if dtype_str in {"float4_e2m1", "mxfloat4_e2m1", "nvfp4_e2m1"}:
self.torch_dtype = torch.uint8
self.scale_dtype = torch.float8_e4m3fn if self.is_nvfp4 else torch.uint8 if self.has_mx_scale else None
self.microblock_size = NVFP_BLOCK_SIZE.value if self.is_nvfp4 else MXFP_BLOCK_SIZE.value if ...
NVFP4는 MXFP와 달리 float8_e4m3fn scale을 사용하고, microblock size는 16이다(MXFP는 32).
출력 FP4 지원
# FP4 출력 시 storage shape 조정 (2개 값이 1바이트에 패킹)
c_storage_shape = c_shape[:-1] + (c_shape[-1] // 2,) if c_dtype.has_mx_scale and c_dtype.is_mxfloat4 else c_shape
c = torch.empty(c_storage_shape, dtype=c_dtype.torch_dtype, device=device)
A-scale swizzling 활성화
# Before: mxfloat8 전용
scale_hbm_swizzling = layout.make_default_matmul_mxfp8_act_scale_layout if a_hbm_swizzling else None
# After: 모든 microscaled 활성화에 적용
scale_hbm_swizzling = layout.make_default_matmul_mx_act_scale_layout if a_hbm_swizzling else None
왜 이게 좋은가
- FP4 x FP4 matmul: 양쪽 모두 4비트인 matmul로 메모리 사용량과 연산량을 대폭 줄인다.
- NVFP4/MXFP4 구분: scale 타입과 block size 차이를 명확히 분리하여 두 표준 모두 지원한다.
- 광범위한 테스트: 다양한 shape, mode(plain/ragged/batched), split_k, swizzling 조합에 대한 테스트가 추가되었다.
정리
+522/-202 변경으로, Blackwell GPU에서의 FP4 matmul 지원을 완성한다. LLM 추론 성능에 직접적으로 기여하는 중요한 기능 추가다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글