본문으로 건너뛰기

[triton] Triton에서 i8 행렬 곱셈 최적화: 레지스터 압력 감소 및 성능 향상

PR 링크: triton-lang/triton#10533 상태: Merged | 변경: +331 / -224

들어가며

Triton은 GPU에서 고성능 딥러닝 연산을 쉽게 작성할 수 있도록 돕는 컴파일러입니다. 특히, 낮은 정밀도(low-precision) 연산은 메모리 대역폭을 절약하고 연산 속도를 높이는 데 중요한 역할을 합니다. 최근 Triton PR (#10533)은 8비트 정수(i8)를 사용하는 행렬 곱셈(MMA, Matrix Multiply-Accumulate) 연산의 성능을 개선하는 데 초점을 맞추었습니다. 이 PR은 기존의 i8 행렬 곱셈 구현에서 발생하는 레지스터 압력(register pressure)을 줄이고, 더 효율적인 연산 흐름을 만들어 전반적인 성능을 향상시키는 것을 목표로 합니다.

이 글에서는 해당 PR의 코드 변경 사항을 분석하고, 어떤 부분이 어떻게 개선되었는지, 그리고 이러한 최적화가 왜 효과적인지에 대해 자세히 알아보겠습니다.

코드 분석

lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp

이 파일은 Triton IR(Intermediate Representation)을 변환하고 최적화하는 로직을 포함하고 있습니다. i8 행렬 곱셈의 최적화는 주로 getMmaEmulationTileShape 함수와 emitI8DotDecomposition 함수의 변경을 통해 이루어졌습니다.

1. getMmaEmulationTileShape 함수의 개선

이 함수는 MMA 연산을 위한 최적의 타일(tile) 크기를 결정하는 역할을 합니다. 기존 코드에서는 i8 MMA 연산을 위한 타일 크기 결정 로직이 다소 단순했습니다. 새로운 PR에서는 i8 MMA 연산에 대한 타일 크기 결정 로직이 더 정교해졌습니다.

Before:

-  if (supportsI8DotDecomposition(rewriter, accElem) && (k % kI8MmaK) == 0) {
-    int64_t tileM = std::min<int64_t>(16 * numWarps, m);
-    int64_t tileN = std::min<int64_t>(8 * numWarps, n);
-    if (canUseI8MmaTile(tileM, tileN, numWarps))
-      tile = {tileM, tileN};
-  }
-  if (directShared) {
-    int64_t widerN = std::min<int64_t>(16 * numWarps, n);
-    if (widerN > tile.second && canUseI8MmaTile(tile.first, widerN, numWarps))
-      tile.second = widerN;
-  }

After:

+  if (!supportsI8DotDecomposition(rewriter, accElem) || (k % kI8MmaK) != 0)
+    return tile;
+
+  // Cap the MMAv2 accumulator at 32 registers per thread.
+  int64_t maxTileArea = 32 * 32 * numWarps / (accElem.getWidth() == 64 ? 2 : 1);
+  for (int64_t tileM = kI8MmaM; tileM <= m; tileM *= 2) {
+    if ((m % tileM) != 0)
+      continue;
+    for (int64_t tileN = kI8MmaN; tileN <= n; tileN *= 2) {
+      if ((n % tileN) == 0 && tileM <= 2 * tileN && tileN <= 2 * tileM &&
+          canUseI8MmaTile(tileM, tileN, numWarps) &&
+          tileM * tileN <= maxTileArea &&
+          tileM * tileN > tile.first * tile.second)
+        tile = {tileM, tileN};
+    }
+  }

설명:

  • 기존에는 numWarps에 비례하여 tileMtileN을 단순하게 결정했습니다. 또한 directShared라는 조건에 따라 tileN을 확장하려는 시도가 있었습니다.
  • 개선된 코드에서는 i8 MMA 연산이 지원되고 k 차원이 kI8MmaK (32)의 배수인 경우에만 i8 MMA 관련 로직을 수행합니다.
  • maxTileArea를 계산하여 타일의 총 면적이 특정 임계값을 넘지 않도록 제한합니다. 이는 레지스터 사용량을 관리하기 위함입니다. (MMAv2는 스레드당 최대 32개의 레지스터를 사용합니다.)
  • tileMtileN을 2의 거듭제곱으로 증가시키면서 가능한 타일 조합을 탐색합니다. tileM <= 2 * tileNtileN <= 2 * tileM 조건은 MMA 연산의 효율성을 높이는 데 도움이 되는 종횡비(aspect ratio)를 유지하기 위한 휴리스틱입니다.
  • 가장 중요한 것은 tileM * tileN > tile.first * tile.second 조건을 통해 이전까지 찾은 최적의 타일보다 더 큰 타일을 찾으면 업데이트한다는 점입니다. 이는 더 큰 타일을 사용하여 연산량을 줄이고 잠재적으로 성능을 향상시킬 수 있음을 의미합니다.

2. emitI8DotDecomposition 함수의 재구성

이 함수는 i8 행렬 곱셈을 실제 하드웨어 MMA 명령어로 변환하는 핵심 로직을 담당합니다. 이전에는 tryEmitI8DotDecomposition이라는 이름으로, i8 MMA 연산이 가능한지 확인하고 가능하면 변환을 시도하는 방식이었습니다. 새로운 PR에서는 emitI8DotDecomposition으로 이름이 변경되었고, 로직이 크게 변경되었습니다.

Before (핵심 로직 일부):

-Value tryEmitI8DotDecomposition(PatternRewriter &rewriter, Location loc,
-                                Value aPayload, Value bPayload,
-                                Attribute accLayout, IntegerType accElem,
-                                int numWarps) {
-  auto aPayloadTy = cast<RankedTensorType>(aPayload.getType());
-  auto bPayloadTy = cast<RankedTensorType>(bPayload.getType());
-  auto aShape = aPayloadTy.getShape();
-  auto bShape = bPayloadTy.getShape();
-  int64_t m = aShape[0];
-  int64_t k = aShape[1];
-  int64_t n = bShape[1];
-  if (bShape[0] != k || (k % kI8MmaK) != 0 ||
-      !supportsI8DotDecomposition(rewriter, accElem)) // Check if i8 MMA is supported
-    return Value();
-  if (!canUseI8MmaTile(m, n, numWarps)) // Check if tile size is feasible
-    return Value();
-  auto aElem = cast<IntegerType>(aPayloadTy.getElementType());
-  auto bElem = cast<IntegerType>(bPayloadTy.getElementType());
-  assert((aElem.getWidth() % 8) == 0 && (bElem.getWidth() % 8) == 0);
-  auto i8Ty = rewriter.getI8Type();
-  auto i32Ty = rewriter.getI32Type();
-  auto mmaLayout = getI8MmaAccumulatorEncoding(
-      rewriter, SmallVector<int64_t>{m, n}, accLayout, numWarps);
-  auto aDotLayout = ttg::DotOperandEncodingAttr::get(ctx, 0, mmaLayout, i8Ty);
-  auto bDotLayout = ttg::DotOperandEncodingAttr::get(ctx, 1, mmaLayout, i8Ty);
-  auto accMmaTy = RankedTensorType::get({m, n}, i32Ty, mmaLayout);
-  auto aMmaTy = aPayloadTy.cloneWithEncoding(aDotLayout);
-  auto bMmaTy = bPayloadTy.cloneWithEncoding(bDotLayout);
-  aPayload = ttg::ConvertLayoutOp::create(rewriter, loc, aMmaTy, aPayload);
-  bPayload = ttg::ConvertLayoutOp::create(rewriter, loc, bMmaTy, bPayload);
-  auto workElem = cast<IntegerType>(accMmaTy.getElementType());
-  assert(workElem == (accElem.getWidth() == 64 ? accElem : i32Ty));
-  // ... (rest of the logic for extracting limbs and emitting dots)

After (핵심 로직):

+Value emitI8DotDecomposition(PatternRewriter &rewriter, Location loc,
+                             Value aPayload, Value bPayload,
+                             IntegerType accElem, Value initialAccumulator) {
+  auto aPayloadTy = cast<RankedTensorType>(aPayload.getType());
+  auto bPayloadTy = cast<RankedTensorType>(bPayload.getType());
+  auto workMmaTy = cast<RankedTensorType>(initialAccumulator.getType());
+  assert(aPayloadTy.getRank() == 2 && bPayloadTy.getRank() == 2 &&
+         workMmaTy.getRank() == 2);
+  auto aShape = aPayloadTy.getShape();
+  auto bShape = bPayloadTy.getShape();
+  auto workShape = workMmaTy.getShape();
+  assert(aShape[1] == bShape[0] && (aShape[1] % kI8MmaK) == 0);
+  assert(aShape[0] == workShape[0] && bShape[1] == workShape[1]);
+  auto aElem = cast<IntegerType>(aPayloadTy.getElementType());
+  auto bElem = cast<IntegerType>(bPayloadTy.getElementType());
+  assert((aElem.getWidth() % 8) == 0 && (bElem.getWidth() % 8) == 0);
+  auto i8Ty = rewriter.getI8Type();
+  auto i32Ty = rewriter.getI32Type();
+  auto mmaLayout = cast<ttg::NvidiaMmaEncodingAttr>(workMmaTy.getEncoding());
+  auto aDotLayout = ttg::DotOperandEncodingAttr::get(ctx, 0, mmaLayout, i8Ty);
+  auto bDotLayout = ttg::DotOperandEncodingAttr::get(ctx, 1, mmaLayout, i8Ty);
+  auto aMmaTy = aPayloadTy.cloneWithEncoding(aDotLayout);
+  auto bMmaTy = bPayloadTy.cloneWithEncoding(bDotLayout);
+  aPayload = ttg::ConvertLayoutOp::create(rewriter, loc, aMmaTy, aPayload);
+  bPayload = ttg::ConvertLayoutOp::create(rewriter, loc, bMmaTy, bPayload);
+  auto workElem = cast<IntegerType>(workMmaTy.getElementType());
+  assert(workElem == (accElem.getWidth() == 64 ? accElem : i32Ty));
+
+  // Peel register repetitions outside each native IMMA fragment from the
+  // largest stride down, then reassemble them in the inverse order.
+  SmallVector<std::pair<unsigned, int64_t>> fragmentSplits;
+  auto mmaLinearLayout = mmaLayout.toLinearLayout(workShape);
+  auto kRegister = StringAttr::get(ctx, "register");
+  const auto &registerBases = mmaLinearLayout.getBases().lookup(kRegister);
+  for (const auto &basis : llvm::reverse(registerBases)) {
+    if (basis[0] >= kI8MmaM && basis[1] == 0)
+      fragmentSplits.emplace_back(0, basis[0]);
+    else if (basis[0] == 0 && basis[1] >= kI8MmaN)
+      fragmentSplits.emplace_back(1, basis[1]);
+  }
+
+  auto splitAtRegisterBasis = [&](Value tensor, unsigned axis,
+                                  int64_t stride) -> std::pair<Value, Value> {
+    // ... (implementation details for splitting tensor)
+  };
+
+  auto joinAtRegisterBasis = [&](Value lhs, Value rhs, unsigned axis,
+                                 int64_t stride) -> Value {
+    // ... (implementation details for joining tensor)
+  };
+
+  auto extractLimb = [&](Value payload, ttg::DotOperandEncodingAttr layout,
+                         int64_t limb) -> Value {
+    // ... (implementation details for extracting i8 limbs)
+  };
+
+  auto emitFragments = [&](auto &&self, Value a, Value b, Value accumulator,
+                           unsigned splitIdx) -> Value {
+    if (splitIdx < fragmentSplits.size()) {
+      // ... (recursive calls to split and combine)
+    }
+
+    auto tileWorkTy = cast<RankedTensorType>(accumulator.getType())
+                          .cloneWithEncoding(mmaLayout);
+    auto accMmaTy = tileWorkTy.clone(i32Ty);
+    accumulator = ttg::ConvertLayoutOp::create(rewriter, loc, tileWorkTy, accumulator);
+
+    // ... (logic to perform i8 MMA operations and accumulate results)
+    return accumulator;
+  };
+
+  return emitFragments(emitFragments, aPayload, bPayload, initialAccumulator, 0);
+}

설명:

  • 입력 변경: emitI8DotDecomposition 함수는 이제 initialAccumulator를 입력으로 받습니다. 이는 기존의 누적기(accumulator)를 재사용하여 연산을 수행함을 의미하며, 이는 불필요한 초기화 연산을 줄여 성능에 도움이 될 수 있습니다.
  • 레지스터 압력 감소: 새로운 로직의 핵심은 fragmentSplitssplitAtRegisterBasis, joinAtRegisterBasis 함수를 사용하는 것입니다. 이들은 MMA 연산의 내부 구조를 분석하여, 큰 덩어리의 데이터를 한 번에 처리하는 대신 작은 조각(fragment)으로 나누어 처리하고, 이 조각들을 레지스터에 효율적으로 배치합니다. 특히, Peel register repetitions outside each native IMMA fragment from the largest stride down, then reassemble them in the inverse order.라는 주석은 이러한 접근 방식의 핵심을 잘 보여줍니다. 데이터를 분해하고 재조립하는 과정에서 레지스터 사용량을 최적화하여 레지스터 압력을 줄입니다.
  • 연산 순서 변경: emitFragments 함수 내에서 i8 MMA 연산(DotI8Op)을 수행하고 결과를 누적하는 방식이 변경되었습니다. 이전에는 각 바이트 대각선(byte diagonal)을 계산하는 방식이었다면, 새로운 방식은 분할된 조각들을 재귀적으로 처리하며 MMA 연산을 수행합니다. 이는 연산의 세분화를 통해 레지스터 할당을 더 유연하게 만들고, 필요한 시점에만 데이터를 로드하여 레지스터 압력을 낮추는 효과를 가져옵니다.
  • 명시적인 변환: ttg::ConvertLayoutOp을 사용하여 데이터 레이아웃을 명시적으로 변환하는 부분이 추가되었습니다. 이는 i8 데이터를 MMA 연산에 적합한 형식으로 변환하는 과정을 명확히 하고, 컴파일러가 최적화할 수 있는 기회를 제공합니다.

왜 이게 좋은가?

이 PR의 주요 개선 사항은 다음과 같습니다.

  1. 레지스터 압력 감소: i8 행렬 곱셈은 종종 많은 중간 결과를 생성합니다. 이전 구현은 이러한 중간 결과들을 레지스터에 유지하려다 레지스터 압력이 높아져 성능 저하를 일으킬 수 있었습니다. 새로운 emitI8DotDecomposition 로직은 데이터를 더 작은 조각으로 나누고, 필요한 데이터만 레지스터에 로드하며, 연산이 끝난 중간 결과는 즉시 다음 연산에 사용하거나 메모리에 쓰는 방식으로 레지스터 압력을 효과적으로 줄입니다. 이는 GPU의 제한된 레지스터 자원을 더 효율적으로 사용하게 하여 성능 향상으로 이어집니다.
  2. 타일 크기 최적화: getMmaEmulationTileShape 함수에서 실험적으로 결정된 휴리스틱을 사용하여 더 큰 타일 크기를 탐색하고, maxTileArea 제한을 통해 레지스터 사용량을 제어합니다. 이는 특정 하드웨어 아키텍처에 더 잘 맞는 타일 크기를 찾아 성능을 극대화하는 데 도움이 됩니다.
  3. 연산 재구성: 데이터를 분해하고 재조립하는 과정은 복잡하지만, 이를 통해 GPU 하드웨어의 MMA 명령어를 더 효율적으로 활용할 수 있습니다. 특히, 레지스터에 데이터를 배치하는 방식을 최적화함으로써 데이터 이동을 최소화하고 연산 속도를 높입니다.

이러한 최적화는 직접적인 성능 수치로 나타날 수 있습니다. 비록 이 PR 설명에는 구체적인 성능 향상 수치가 명시되어 있지 않지만, 일반적으로 레지스터 압력 감소는 GPU 커널의 성능을 크게 향상시키는 요인입니다. 특히 i8과 같이 낮은 정밀도를 사용하는 연산에서는 더 많은 데이터를 처리할 수 있게 되어 메모리 대역폭 병목 현상을 완화하는 데도 기여할 수 있습니다.

일반적 교훈:

  • 레지스터 압력 관리는 GPU 최적화의 핵심입니다. 복잡한 연산일수록 중간 결과를 효율적으로 관리하는 것이 중요합니다. 데이터를 작게 분할하고, 필요한 데이터만 로드하며, 연산 순서를 최적화하는 기법을 고려해야 합니다.
  • 하드웨어 특성을 이해하고 활용해야 합니다. MMA 명령어와 같은 하드웨어 기능을 최대한 활용하기 위해 데이터 레이아웃과 연산 순서를 조정하는 것이 중요합니다.
  • 실험적 휴리스틱은 유용할 수 있습니다. 특히 타일 크기나 종횡비와 같이 성능에 민감한 파라미터는 실험을 통해 최적값을 찾는 것이 효과적일 수 있습니다.

결론

이번 Triton PR은 i8 행렬 곱셈 연산의 성능을 향상시키기 위해 레지스터 압력 감소와 연산 재구성에 초점을 맞춘 중요한 개선 사항을 포함하고 있습니다. getMmaEmulationTileShape 함수의 정교한 타일 크기 결정과 emitI8DotDecomposition 함수의 데이터 분할 및 재조립 로직은 GPU 하드웨어의 효율성을 극대화하는 좋은 예시입니다. 이러한 최적화는 Triton이 더 빠르고 효율적인 딥러닝 모델을 구축하는 데 기여할 것입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글