본문으로 건너뛰기

[Triton] 소규모 async_cp를 위한 최적 레이아웃 선택

PR 링크: triton-lang/triton#9183 상태: Merged | 변경: +140 / -41

들어가며

Triton 컴파일러에서 작은 크기의 load 연산은 사용자(user) 기반으로 레이아웃을 결정한다. 하지만 이 load가 async_copy로 변환될 때, source 레이아웃과 destination 레이아웃이 분리되므로 사용자 기반 레이아웃 대신 coalesced 레이아웃을 선택하는 것이 더 효율적이다. 이 PR은 소규모 async copy에 대해 독립적인 최적 레이아웃을 선택하는 패턴을 추가한다.

핵심 코드 분석

Before: 기존 ClipAsyncCopySizePerThread만 존재

기존에는 CoalesceAsyncCopy pass에 ClipAsyncCopySizePerThread 패턴만 있었다. 이 패턴은 upcast 시나리오에서 sizePerThread를 줄이는 역할만 했다.

// Before: 직접 src/mask/other를 변환하고 contiguity를 업데이트
src = convertBlockLayout(src, newBlockEnc);
if (mask)
  mask = convertBlockLayout(mask, newBlockEnc);
if (other)
  other = convertBlockLayout(other, newBlockEnc);

unsigned contiguity = axisInfoAnalysis.getContiguity(src);
if (mask)
  contiguity = std::min<unsigned>(contiguity,
                                  axisInfoAnalysis.getMaskAlignment(mask));

rewriter.modifyOpInPlace(copyOp, [&]() {
  copyOp.getSrcMutable().assign(src);
  // ...
  copyOp.setContiguity(contiguity);
});

After: 공통 함수 추출 + 새로운 CoalesceCheapAsyncCopyGlobalToLocal 패턴

레이아웃 변환 로직을 공통 함수로 추출하고, 소규모 copy를 위한 새 패턴을 추가했다.

// After: 공통 함수로 추출
static void retargetCopyOperandsToEncoding(
    AsyncCopyGlobalToLocalOp copyOp, Attribute newEncoding,
    ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternRewriter &rewriter) {
  Value src = copyOp.getSrc();
  Value mask = copyOp.getMask();
  Value other = copyOp.getOther();

  src = convertValueLayout(src, newEncoding, rewriter);
  if (mask)
    mask = convertValueLayout(mask, newEncoding, rewriter);
  if (other)
    other = convertValueLayout(other, newEncoding, rewriter);

  unsigned contiguity = axisInfoAnalysis.getContiguity(src);
  if (mask)
    contiguity =
        std::min<unsigned>(contiguity, axisInfoAnalysis.getMaskAlignment(mask));

  rewriter.modifyOpInPlace(copyOp, [&]() {
    copyOp.getSrcMutable().assign(src);
    if (mask) copyOp.getMaskMutable().assign(mask);
    if (other) copyOp.getOtherMutable().assign(other);
    copyOp.setContiguity(contiguity);
  });
}
// 새로운 패턴: 소규모 async copy에 coalesced encoding 적용
struct CoalesceCheapAsyncCopyGlobalToLocal
    : public OpRewritePattern<AsyncCopyGlobalToLocalOp> {
  LogicalResult matchAndRewrite(AsyncCopyGlobalToLocalOp copyOp,
                                PatternRewriter &rewriter) const override {
    int64_t size = srcTy.getNumElements();
    // 큰 copy는 이미 coalesced -> skip
    // 32bit 미만 dtype은 contiguity 문제로 skip
    if (size >= numWarps * threadsPerWarp ||
        dstTy.getElementTypeBitWidth() < 32)
      return failure();

    auto newEnc = coalescedAsyncCopyMap[copyOp];
    if (newEnc == nullptr || newEnc == srcTy.getEncoding())
      return failure();

    retargetCopyOperandsToEncoding(copyOp, newEnc, axisInfoAnalysis, rewriter);
    return success();
  }
};

왜 이게 좋은가

  1. 불필요한 convert_layout 제거: async copy의 src/dst 레이아웃이 독립적이므로, source에 coalesced 레이아웃을 적용해 메모리 접근 효율을 높인다.
  2. 코드 재사용: 레이아웃 변환 로직을 공통 함수로 추출하여 두 패턴에서 재사용한다.
  3. 안전한 적용 범위: 32bit 이상 dtype, 소규모 텐서(numWarps * threadsPerWarp 미만)에만 적용하여 부작용을 최소화한다.
  4. buildCoalescedEncoding API 개선: 불필요한 MLIRContext* 파라미터를 제거하여 인터페이스를 단순화했다.

정리

이 PR은 Triton의 async copy 최적화 pass에 소규모 텐서를 위한 coalesced 레이아웃 선택 패턴을 추가했다. 기존에는 사용자 기반 레이아웃이 그대로 사용되어 비효율적인 메모리 접근이 발생했지만, 이제 async copy의 source 레이아웃을 독립적으로 최적화할 수 있다.

참고 자료


이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.

댓글

관련 포스트

PR Analysis 의 다른글