본문으로 건너뛰기

[triton] AMD GPU에서 Block Scaled Matmul 지원 추가

PR 링크: triton-lang/triton#8099 상태: Merged | 변경: +289 / -53

들어가며

Block scaled matrix multiplication은 FP4/FP8 같은 저정밀 포맷에서 스케일 팩터를 블록 단위로 적용하는 연산입니다. 기존에는 NVIDIA Blackwell GPU만 지원했으나, 이 PR은 AMD CDNA4(gfx950) 아키텍처에서도 동작하도록 확장합니다. 특히 AMD의 MFMA scaled 명령어에 맞는 스케일 프리셔플링 로직과 별도의 matmul 커널을 추가합니다.

핵심 코드 분석

Before

# supports_block_scaling은 NVIDIA만 지원
def supports_block_scaling():
    return is_cuda() and torch.cuda.get_device_capability()[0] == 10

After

def is_hip_cdna4():
    target = triton.runtime.driver.active.get_current_target()
    return target is not None and target.backend == 'hip' and target.arch == 'gfx950'

def supports_block_scaling():
    return (is_cuda() and torch.cuda.get_device_capability()[0] == 10) or is_hip_cdna4()

AMD용 스케일 프리셔플링 문서화

# Scale preshuffling on AMD GPUs
# MFMA 16x16x128: 4 threads along K, 16 along M/N
# Packing order: mfma_op_0, mfma_op_2, mfma_op_1, mfma_op_3
#            K = 128       K = 128
#        +------------+ +------------+
#    M=16|  MFMA op 0 | |  MFMA op 1 |
#        +------------+ +------------+
#    M=16|  MFMA op 2 | |  MFMA op 3 |
#        +------------+ +------------+

왜 이게 좋은가

  1. 크로스 벤더 지원: 동일한 튜토리얼 코드로 NVIDIA와 AMD GPU 모두에서 block scaled matmul을 실행할 수 있습니다.
  2. 문서 이전: 테스트 코드에 있던 스케일 셔플링 설명을 공식 튜토리얼로 옮겨 접근성을 높였습니다.
  3. 아키텍처 특화 커널: AMD CDNA4의 MFMA scaled 명령어에 최적화된 별도 커널(block_scaled_matmul_kernel_cdna4)을 제공합니다.

정리

이 PR은 Triton의 block scaled matmul 기능을 AMD CDNA4 GPU로 확장하면서, 스케일 프리셔플링의 원리를 잘 문서화한 좋은 사례입니다. 특히 벤더별 아키텍처 차이(Tensor Core vs MFMA)를 추상화하면서도 각각에 최적화된 커널을 제공하는 접근이 인상적입니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글