[triton] Gluon에서 3D Dot FMA 연산 노출
PR 링크: triton-lang/triton#9501 상태: Merged | 변경: +41 / -7
들어가며
Gluon의 dot_fma 함수는 FMA(Fused Multiply-Add) 기반의 소프트웨어 행렬 곱셈을 제공합니다. 기존에는 2D 텐서만 지원했는데, 이 PR은 3D(batched) 텐서도 지원하도록 확장합니다.
핵심 코드 분석
Before
def dot_fma(a, b, acc, _semantic=None):
M, N = acc.shape
K = a.shape[1]
if M * N * K > 2**19:
warnings.warn(f"Large dot FMA instruction size {M}x{N}x{K}")
After
def dot_fma(a, b, acc, _semantic=None):
assert len(acc.shape) == 2 or len(acc.shape) == 3
assert len(acc.shape) == len(a.shape) == len(b.shape)
unified_dot_shape = acc.shape + a.shape[-1:]
if math.prod(unified_dot_shape) > 2**19:
dot_name = "batched dot" if len(acc.shape) == 3 else "dot"
shape_str = "x".join([str(x) for x in unified_dot_shape])
warnings.warn(f"Large {dot_name} FMA instruction size {shape_str}")
FMA lowering 수정
// sizePerThread를 getContigPerThread 대신 layout에서 직접 가져옴
auto sizePerThread = getContigPerThread(dTensorTy); // Before
llvm::SmallVector<unsigned> sizePerThread{dLayout.getSizePerThread()}; // After
왜 이게 좋은가
- Batched 지원: Flash Attention 등에서 batch 차원이 있는 matmul을 FMA 경로로 실행할 수 있습니다.
- 일반화된 검증: 2D/3D를 통합하여 shape 검증과 크기 경고를 처리합니다.
- 올바른 sizePerThread: 3D 레이아웃에서 contig per thread가 아닌 전체 size per thread를 사용하여 정확한 값을 얻습니다.
정리
FMA dot을 3D 텐서로 확장한 작은 변경이지만, batched matmul 파이프라인에서 중요한 빌딩 블록입니다. lowering 수정도 함께 이루어져 end-to-end 정합성이 보장됩니다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
- [triton] Multi-CTA 예제에서 Program ID를 Shared Memory에 저장하여 재계산 방지
- [triton] Blackwell GPU Cluster Launch Control 지원으로 Persistent Kernel 워크로드 밸런싱 구현
- [triton] AMD TDM의 Partition-Aware 분할 및 다중 Intrinsic 지원
- [triton] AMD GFX9 Async Copy에서 Shared Memory 순서 버그 수정
- [triton] GSan AxisInfo 기반 Shadow Update 중복 제거로 2~10배 성능 향상
PR Analysis 의 다른글
- 이전글 [Loki] TSDBIndex.GetChunkRefs에서 불필요한 라벨 조회 제거
- 현재글 : [triton] Gluon에서 3D Dot FMA 연산 노출
- 다음글 [triton] Backend별 global_scratch_alloc 할당 통합
댓글