[triton] Gluon에 mma_scaled 연산 헬퍼 및 실행 테스트 추가
PR 링크: triton-lang/triton#8410 상태: Merged | 변경: +154 / -5
들어가며
NVIDIA Blackwell의 5세대 Tensor Core는 tcgen05_mma_scaled 명령을 통해 스케일이 적용된 행렬 곱셈을 하드웨어 수준에서 지원합니다. 이 PR은 Triton의 실험적 프론트엔드인 Gluon에 해당 연산의 헬퍼 함수, 스케일 레이아웃 생성 함수, 그리고 실행 정합성 테스트를 추가합니다.
핵심 코드 분석
스케일용 Tensor Memory 레이아웃 생성
@constexpr_function
def get_tmem_scales_reg_layout(M, N, shape, num_warps, ...):
"""Return a linear layout compatible with tmem scaled layout."""
assert len(shape) == 2, "expected a 2D tensor"
assert num_warps in [4, 8], "expected 4 or 8 warps"
# Scale은 32개 원소마다 1개이므로 K//32 크기의 텐서
테스트에서 스케일 적용 검증
# e8m0 스케일을 float32로 변환하여 레퍼런스 계산
def fp8e8m0_to_float32(scale):
scale = scale.view(torch.uint8)
scale = scale.to(torch.int32)
scale = scale << 23
scale = scale.view(torch.float32)
return scale
# 실행 결과 비교
A = A * a_scale_f32
B = B * b_scale_f32
ref = torch.matmul(A, B)
torch.testing.assert_close(out, ref, atol=1e-6, rtol=1e-6)
왜 이게 좋은가
- 하드웨어 직접 접근: Gluon을 통해 Blackwell의 scaled MMA를 직접 활용할 수 있어 FP8 학습/추론의 성능을 극대화합니다.
- 정합성 테스트: e8m0 스케일의 비트 조작을 포함한 완전한 end-to-end 테스트로 정확성을 보장합니다.
- 레이아웃 추상화: 복잡한 TMEM 스케일 레이아웃을
get_tmem_scales_reg_layout함수로 캡슐화하여 사용 편의성을 높였습니다.
정리
이 PR은 Gluon 사용자가 Blackwell의 scaled MMA를 활용할 수 있게 하는 기능 추가입니다. FP8 연산에서 블록 스케일링은 정밀도와 성능의 균형을 맞추는 핵심 기법이며, 하드웨어 가속을 직접 활용할 수 있다는 점에서 의미가 큽니다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
- [triton] Gluon tmem_load에서 Register Layout 자동 추론
- [Triton] 2CTA Block Scale MMA with tcgen05.cp — 두 CTA 협력 행렬 곱셈
- [triton] Blackwell GPU Cluster Launch Control 지원으로 Persistent Kernel 워크로드 밸런싱 구현
- [triton] AMD WMMA Utilization 개선: Unroll 제거와 상수 폴딩
- [triton] Consumer Blackwell(sm_120)에서 PTX Codegen Segfault 수정
PR Analysis 의 다른글
- 이전글 [Open WebUI] RecursiveFolder 컴포넌트 지연 로딩으로 페이지 로드 속도 개선
- 현재글 : [triton] Gluon에 mma_scaled 연산 헬퍼 및 실행 테스트 추가
- 다음글 [Triton] split_k에 m*n 제약 조건 추가
댓글