본문으로 건너뛰기

[triton] Gluon에서 3D Dot FMA 연산 노출

PR 링크: triton-lang/triton#9501 상태: Merged | 변경: +41 / -7

들어가며

Gluon의 dot_fma 함수는 FMA(Fused Multiply-Add) 기반의 소프트웨어 행렬 곱셈을 제공합니다. 기존에는 2D 텐서만 지원했는데, 이 PR은 3D(batched) 텐서도 지원하도록 확장합니다.

핵심 코드 분석

Before

def dot_fma(a, b, acc, _semantic=None):
    M, N = acc.shape
    K = a.shape[1]
    if M * N * K > 2**19:
        warnings.warn(f"Large dot FMA instruction size {M}x{N}x{K}")

After

def dot_fma(a, b, acc, _semantic=None):
    assert len(acc.shape) == 2 or len(acc.shape) == 3
    assert len(acc.shape) == len(a.shape) == len(b.shape)
    
    unified_dot_shape = acc.shape + a.shape[-1:]
    if math.prod(unified_dot_shape) > 2**19:
        dot_name = "batched dot" if len(acc.shape) == 3 else "dot"
        shape_str = "x".join([str(x) for x in unified_dot_shape])
        warnings.warn(f"Large {dot_name} FMA instruction size {shape_str}")

FMA lowering 수정

// sizePerThread를 getContigPerThread 대신 layout에서 직접 가져옴
auto sizePerThread = getContigPerThread(dTensorTy);  // Before
llvm::SmallVector<unsigned> sizePerThread{dLayout.getSizePerThread()};  // After

왜 이게 좋은가

  1. Batched 지원: Flash Attention 등에서 batch 차원이 있는 matmul을 FMA 경로로 실행할 수 있습니다.
  2. 일반화된 검증: 2D/3D를 통합하여 shape 검증과 크기 경고를 처리합니다.
  3. 올바른 sizePerThread: 3D 레이아웃에서 contig per thread가 아닌 전체 size per thread를 사용하여 정확한 값을 얻습니다.

정리

FMA dot을 3D 텐서로 확장한 작은 변경이지만, batched matmul 파이프라인에서 중요한 빌딩 블록입니다. lowering 수정도 함께 이루어져 end-to-end 정합성이 보장됩니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글