본문으로 건너뛰기

[triton] Gluon에 mma_scaled 연산 헬퍼 및 실행 테스트 추가

PR 링크: triton-lang/triton#8410 상태: Merged | 변경: +154 / -5

들어가며

NVIDIA Blackwell의 5세대 Tensor Core는 tcgen05_mma_scaled 명령을 통해 스케일이 적용된 행렬 곱셈을 하드웨어 수준에서 지원합니다. 이 PR은 Triton의 실험적 프론트엔드인 Gluon에 해당 연산의 헬퍼 함수, 스케일 레이아웃 생성 함수, 그리고 실행 정합성 테스트를 추가합니다.

핵심 코드 분석

스케일용 Tensor Memory 레이아웃 생성

@constexpr_function
def get_tmem_scales_reg_layout(M, N, shape, num_warps, ...):
    """Return a linear layout compatible with tmem scaled layout."""
    assert len(shape) == 2, "expected a 2D tensor"
    assert num_warps in [4, 8], "expected 4 or 8 warps"
    # Scale은 32개 원소마다 1개이므로 K//32 크기의 텐서

테스트에서 스케일 적용 검증

# e8m0 스케일을 float32로 변환하여 레퍼런스 계산
def fp8e8m0_to_float32(scale):
    scale = scale.view(torch.uint8)
    scale = scale.to(torch.int32)
    scale = scale << 23
    scale = scale.view(torch.float32)
    return scale

# 실행 결과 비교
A = A * a_scale_f32
B = B * b_scale_f32
ref = torch.matmul(A, B)
torch.testing.assert_close(out, ref, atol=1e-6, rtol=1e-6)

왜 이게 좋은가

  1. 하드웨어 직접 접근: Gluon을 통해 Blackwell의 scaled MMA를 직접 활용할 수 있어 FP8 학습/추론의 성능을 극대화합니다.
  2. 정합성 테스트: e8m0 스케일의 비트 조작을 포함한 완전한 end-to-end 테스트로 정확성을 보장합니다.
  3. 레이아웃 추상화: 복잡한 TMEM 스케일 레이아웃을 get_tmem_scales_reg_layout 함수로 캡슐화하여 사용 편의성을 높였습니다.

정리

이 PR은 Gluon 사용자가 Blackwell의 scaled MMA를 활용할 수 있게 하는 기능 추가입니다. FP8 연산에서 블록 스케일링은 정밀도와 성능의 균형을 맞추는 핵심 기법이며, 하드웨어 가속을 직접 활용할 수 있다는 점에서 의미가 큽니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글