본문으로 건너뛰기

[Triton] Frontend에서 scaled batched matrix multiply 지원

PR 링크: triton-lang/triton#9000 상태: Merged | 변경: +147 / -5

들어가며

Triton의 dot_scaled 연산은 FP8/FP4 등의 scaled dot product를 수행한다. 이전 PR(#8564)에서 scale 텐서의 차원 검증이 추가되었지만, 전체 shape를 비교했기 때문에 BMM(Batched Matrix Multiply) 피연산자에서는 검증이 실패했다. 이 PR은 마지막 2차원만 비교하도록 수정하여 BMM을 지원한다.

핵심 코드 분석

Before: 전체 shape 비교

LogicalResult DotScaledOp::verify() {
  auto aShape = this->getA().getType().getShape();
  int64_t rank = aShape.size();
  // rank가 3 이상이면 batch 차원 때문에 검증 실패
  auto k = aShape[rank - 1];
  // ...
}

scale shape 검증 에러 메시지: "lhs_scale must be a tensor of shape [32, 2]. Got ['32', '4']"

After: 마지막 2차원만 비교 + rank 검증

LogicalResult DotScaledOp::verify() {
  auto aShape = this->getA().getType().getShape();
  int64_t rank = aShape.size();
  if (rank < 2)
    return this->emitError("operands must be at least 2D");
  // 마지막 2차원 기준으로 k, scale shape 검증
  auto k = aShape[rank - 1];
  // ...
}

scale shape 검증 에러 메시지: "lhs_scale must be a tensor of shape [..., 32, 2]. Got ['32', '4']"

Batched MXFP matmul 테스트

@triton.jit
def batched_mxfp_matmul(a_ptr, b_ptr, output_ptr, a_scale, b_scale,
                         M, N, K, ...,
                         BLOCK_BATCH_SIZE: tl.constexpr,
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                         BLOCK_K: tl.constexpr):
    offs_batch = (batch_id * BLOCK_BATCH_SIZE
                  + tl.arange(0, BLOCK_BATCH_SIZE)) % BATCH_SIZE
    # 3D 텐서: [BLOCK_BATCH_SIZE, BLOCK_M, BLOCK_K]
    a = tl.load(a_ptrs)
    b = tl.load(b_ptrs)
    scale_a = tl.load(a_scale_ptr)
    scale_b = tl.load(b_scale_ptr)
    accumulator = tl.dot_scaled(a, scale_a, "e5m2",
                                b, scale_b, "e5m2", accumulator)

왜 이게 좋은가

  1. BMM 지원: batch 차원이 있는 3D+ 텐서에서도 dot_scaled를 사용할 수 있다.
  2. 기존 검증 유지: 마지막 2차원의 scale shape 검증은 그대로 유지하여 안전성을 보장한다.
  3. 에러 메시지 개선: [32, 2] 대신 [..., 32, 2]로 표시하여 batch 차원이 있을 수 있음을 명시한다.
  4. rank 검증 추가: 1D 텐서 등 비정상 입력을 조기에 차단한다.

정리

이 PR은 dot_scaled의 scale shape 검증을 전체 shape가 아닌 마지막 2차원 기준으로 변경하여, BMM 피연산자를 지원한다. 다양한 batch 크기와 블록 크기 조합에 대한 end-to-end 테스트도 포함한다.

참고 자료


이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.

댓글

관련 포스트

PR Analysis 의 다른글