본문으로 건너뛰기

[Triton] matmul 커널 시그니처에 input microblock size 추가

들어가며

Triton의 matmul 커널에서 microscaled 연산(MXFP, NVFP4)을 사용할 때, microblock size(스케일 블록 크기)를 텐서 shape에서 추론하던 방식에서 명시적으로 PrecisionConfig에 지정하는 방식으로 변경하는 PR이다. NVFP4(block size 16)와 MXFP(block size 32)를 구분하기 위해 필요하다.

핵심 코드 분석

Before

a_microblock_size = None if a_scale is None else a.shape[-1] // a_scale.shape[-1]
b_microblock_size = None if b_scale is None else b.shape[-2] // b_scale.shape[-2]

microblock size를 텐서 shape 비율로 추론했다. 이 방식은 MXFP와 NVFP4를 구분할 수 없었다.

After

@dataclass
class PrecisionConfig:
    a_mx_scale: torch.Tensor | Tensor | None = None
    a_microblock_size: int | None = None  # 새로 추가
    b_mx_scale: torch.Tensor | Tensor | None = None
    b_microblock_size: int | None = None  # 새로 추가

# 사용 시 명시적 지정
pc = PrecisionConfig(
    b_mx_scale=b_scale,
    b_microblock_size=MXFP_BLOCK_SIZE.value,  # 32
)

검증 로직도 추가되었다:

assert b_scale is None or b_microblock_size is not None, (
    "precision_config.b_microblock_size is required when precision_config.b_mx_scale is set"
)

왜 이게 좋은가

  • 명시성: microblock size를 추론이 아닌 선언으로 표현하여 NVFP4(16)와 MXFP(32) 구분이 명확하다.
  • 검증 강화: scale이 있는데 microblock size가 없으면 즉시 오류를 발생시킨다.
  • NVFP4 x NVFP4 지원의 기반: 이 PR과 함께 nvfp4 x nvfp4 matmul 테스트 케이스가 추가되었다.

정리

API 설계에서 "추론 가능하더라도 명시적으로 요구하는 것"이 오류를 줄이는 좋은 방법임을 보여주는 PR이다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글