[Triton] Frontend에서 scaled batched matrix multiply 지원
PR 링크: triton-lang/triton#9000 상태: Merged | 변경: +147 / -5
들어가며
Triton의 dot_scaled 연산은 FP8/FP4 등의 scaled dot product를 수행한다. 이전 PR(#8564)에서 scale 텐서의 차원 검증이 추가되었지만, 전체 shape를 비교했기 때문에 BMM(Batched Matrix Multiply) 피연산자에서는 검증이 실패했다. 이 PR은 마지막 2차원만 비교하도록 수정하여 BMM을 지원한다.
핵심 코드 분석
Before: 전체 shape 비교
LogicalResult DotScaledOp::verify() {
auto aShape = this->getA().getType().getShape();
int64_t rank = aShape.size();
// rank가 3 이상이면 batch 차원 때문에 검증 실패
auto k = aShape[rank - 1];
// ...
}
scale shape 검증 에러 메시지: "lhs_scale must be a tensor of shape [32, 2]. Got ['32', '4']"
After: 마지막 2차원만 비교 + rank 검증
LogicalResult DotScaledOp::verify() {
auto aShape = this->getA().getType().getShape();
int64_t rank = aShape.size();
if (rank < 2)
return this->emitError("operands must be at least 2D");
// 마지막 2차원 기준으로 k, scale shape 검증
auto k = aShape[rank - 1];
// ...
}
scale shape 검증 에러 메시지: "lhs_scale must be a tensor of shape [..., 32, 2]. Got ['32', '4']"
Batched MXFP matmul 테스트
@triton.jit
def batched_mxfp_matmul(a_ptr, b_ptr, output_ptr, a_scale, b_scale,
M, N, K, ...,
BLOCK_BATCH_SIZE: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr):
offs_batch = (batch_id * BLOCK_BATCH_SIZE
+ tl.arange(0, BLOCK_BATCH_SIZE)) % BATCH_SIZE
# 3D 텐서: [BLOCK_BATCH_SIZE, BLOCK_M, BLOCK_K]
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
scale_a = tl.load(a_scale_ptr)
scale_b = tl.load(b_scale_ptr)
accumulator = tl.dot_scaled(a, scale_a, "e5m2",
b, scale_b, "e5m2", accumulator)
왜 이게 좋은가
- BMM 지원: batch 차원이 있는 3D+ 텐서에서도
dot_scaled를 사용할 수 있다. - 기존 검증 유지: 마지막 2차원의 scale shape 검증은 그대로 유지하여 안전성을 보장한다.
- 에러 메시지 개선:
[32, 2]대신[..., 32, 2]로 표시하여 batch 차원이 있을 수 있음을 명시한다. - rank 검증 추가: 1D 텐서 등 비정상 입력을 조기에 차단한다.
정리
이 PR은 dot_scaled의 scale shape 검증을 전체 shape가 아닌 마지막 2차원 기준으로 변경하여, BMM 피연산자를 지원한다. 다양한 batch 크기와 블록 크기 조합에 대한 end-to-end 테스트도 포함한다.
참고 자료
- triton-lang/triton#9000
- triton-lang/triton#8564 (원본 검증 추가 PR)
이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Triton] AMD scf.if else 분기 누락 버그 수정 — deduceMinCountBetweeOps
- 현재글 : [Triton] Frontend에서 scaled batched matrix multiply 지원
- 다음글 [Triton] Gluon 검증 로직을 C++ verifier로 이동 — 차원 축소 로드 지원
댓글