본문으로 건너뛰기

[Triton] TMEM Store 레이아웃 변환 최적화 — FlexAttention 성능 복구

PR 링크: triton-lang/triton#8353 상태: Merged | 변경: +46 / -0

들어가며

Triton은 GPU 커널을 Python으로 작성할 수 있게 해주는 컴파일러다. 내부적으로 MLIR IR을 사용하며, 다양한 canonicalization 패턴을 통해 불필요한 연산을 제거한다. 이 PR은 TMEM(Tensor Memory) Store 앞에 붙는 불필요한 convert_layout 연산을 fold(접기)하여, FlexAttention 커널의 성능 저하를 수정한다.

NVIDIA Blackwell 아키텍처의 TMEM은 Tensor Core에 직접 연결된 고속 메모리다. tmem_store는 레지스터의 텐서를 이 TMEM에 쓰는 연산인데, 입력 텐서의 레이아웃이 TMEM과 호환되면 별도의 변환 없이 바로 저장할 수 있다.

핵심 코드 분석

Canonicalization 패턴 추가

기존 Triton에는 reshape(cvt) -> reshape, local_store(cvt) -> local_store 등의 패턴이 있었지만, tmem_store(cvt) 패턴은 없었다.

After (Ops.cpp):

// tmem_store(cvt) -> tmem_store
struct CanonicalizeConvertFromTMEMStore
    : public mlir::OpRewritePattern<nvidia_gpu::TMEMStoreOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(nvidia_gpu::TMEMStoreOp op,
                  PatternRewriter &rewriter) const override {
    auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
    if (!convert)
      return failure();

    // bail for incompatible layouts
    auto cvtSrcType = convert.getSrc().getType();
    if (!nvidia_gpu::isDistributedLayoutTMemCompatible(
            op.getOperation(), cvtSrcType, op.getDst().getType())) {
      return failure();
    }

    rewriter.modifyOpInPlace(
        op, [&]() { op.getSrcMutable().assign(convert.getSrc()); });
    return mlir::success();
  }
};

이 패턴의 동작 원리는 다음과 같다:

  1. tmem_store의 입력이 convert_layout 연산의 결과인지 확인한다
  2. convert_layout 이전의 원본 레이아웃이 TMEM과 호환되는지 isDistributedLayoutTMemCompatible로 검증한다
  3. 호환되면 convert_layout을 건너뛰고 원본 텐서를 직접 tmem_store에 연결한다

패턴 등록

void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  // ... 기존 패턴들 ...
  patterns.add<CanonicalizeConvertFromSplit>(context);
  patterns.add<CanonicalizeConvertFromTMEMStore>(context);  // 신규
}

MLIR 테스트

#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1],
                         warpsPerCTA = [4, 2], order = [0, 1]}>
#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]],
                       lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]],
                       warp = [[32, 0], [64, 0], [16, 0]], block = []}>
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, colStride = 1>

// CHECK-NOT: ttg.convert_layout
%1 = ttg.convert_layout %arg0 : tensor<128x64xbf16, #linear> -> tensor<128x64xbf16, #blocked>
// CHECK: ttng.tmem_store %{{.*}} : tensor<128x64xbf16, #linear> ->
ttng.tmem_store %1, %arg1, %true : tensor<128x64xbf16, #blocked> -> ...

테스트는 #linear 레이아웃에서 #blocked로의 convert_layout이 제거되고, tmem_store#linear 텐서를 직접 사용하는 것을 검증한다.

왜 이게 좋은가

IR 변환 전:

%0 = ... : tensor<128x64xbf16, #linear>
%1 = ttg.convert_layout %0 : #linear -> #blocked    ← 불필요한 데이터 이동
ttng.tmem_store %1, %mem : #blocked -> tmem

IR 변환 후:

%0 = ... : tensor<128x64xbf16, #linear>
ttng.tmem_store %0, %mem : #linear -> tmem           ← 직접 저장

convert_layout은 레지스터 간 데이터 셔플링을 수반하므로, 이를 제거하면 FlexAttention 커널에서 불필요한 warp-level 통신이 사라진다. FlexAttention은 attention 연산의 핵심 경로에서 TMEM Store를 빈번하게 사용하기 때문에, 이 최적화의 효과가 누적되어 체감 성능 차이가 발생한다.

정리

컴파일러 최적화의 정석을 보여주는 PR이다. (1) 불필요한 연산 패턴을 식별하고, (2) 호환성 검증 후 안전하게 제거하며, (3) MLIR FileCheck 테스트로 검증한다. 기존 local_store, reshape 등의 fold 패턴과 동일한 구조를 따르면서 TMEM Store에 확장 적용한 점이 깔끔하다. 하드웨어 특화 메모리(TMEM)와 컴파일러 IR의 상호작용을 이해하는 데 좋은 사례다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글