본문으로 건너뛰기

[Triton] matmul 커널에 nvfp4 x nvfp4, mxfp4 x mxfp4 지원 추가

들어가며

4비트 부동소수점(FP4)은 LLM 추론에서 메모리 대역폭과 계산 비용을 크게 줄이는 핵심 기술이다. 이 PR은 Triton의 matmul 커널에서 양쪽 피연산자가 모두 FP4인 경우(nvfp4 x nvfp4, mxfp4 x mxfp4)를 지원하도록 확장하고, quantization 유틸리티에 NVFP4 지원을 추가한다.

핵심 코드 분석

DType 클래스 확장

# Before
self.has_mx_scale = dtype_str.startswith("mx")
self.torch_dtype = to_torch_dtype(dtype_str.strip("mx"))

# After
self.is_nvfp4 = dtype_str == "nvfp4_e2m1"
self.has_mx_scale = dtype_str.startswith("mx") or self.is_nvfp4
if dtype_str in {"float4_e2m1", "mxfloat4_e2m1", "nvfp4_e2m1"}:
    self.torch_dtype = torch.uint8
self.scale_dtype = torch.float8_e4m3fn if self.is_nvfp4 else torch.uint8 if self.has_mx_scale else None
self.microblock_size = NVFP_BLOCK_SIZE.value if self.is_nvfp4 else MXFP_BLOCK_SIZE.value if ...

NVFP4는 MXFP와 달리 float8_e4m3fn scale을 사용하고, microblock size는 16이다(MXFP는 32).

출력 FP4 지원

# FP4 출력 시 storage shape 조정 (2개 값이 1바이트에 패킹)
c_storage_shape = c_shape[:-1] + (c_shape[-1] // 2,) if c_dtype.has_mx_scale and c_dtype.is_mxfloat4 else c_shape
c = torch.empty(c_storage_shape, dtype=c_dtype.torch_dtype, device=device)

A-scale swizzling 활성화

# Before: mxfloat8 전용
scale_hbm_swizzling = layout.make_default_matmul_mxfp8_act_scale_layout if a_hbm_swizzling else None

# After: 모든 microscaled 활성화에 적용
scale_hbm_swizzling = layout.make_default_matmul_mx_act_scale_layout if a_hbm_swizzling else None

왜 이게 좋은가

  • FP4 x FP4 matmul: 양쪽 모두 4비트인 matmul로 메모리 사용량과 연산량을 대폭 줄인다.
  • NVFP4/MXFP4 구분: scale 타입과 block size 차이를 명확히 분리하여 두 표준 모두 지원한다.
  • 광범위한 테스트: 다양한 shape, mode(plain/ragged/batched), split_k, swizzling 조합에 대한 테스트가 추가되었다.

정리

+522/-202 변경으로, Blackwell GPU에서의 FP4 matmul 지원을 완성한다. LLM 추론 성능에 직접적으로 기여하는 중요한 기능 추가다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글