[Triton] TMEM Store 레이아웃 변환 최적화 — FlexAttention 성능 복구
PR 링크: triton-lang/triton#8353 상태: Merged | 변경: +46 / -0
들어가며
Triton은 GPU 커널을 Python으로 작성할 수 있게 해주는 컴파일러다. 내부적으로 MLIR IR을 사용하며, 다양한 canonicalization 패턴을 통해 불필요한 연산을 제거한다. 이 PR은 TMEM(Tensor Memory) Store 앞에 붙는 불필요한 convert_layout 연산을 fold(접기)하여, FlexAttention 커널의 성능 저하를 수정한다.
NVIDIA Blackwell 아키텍처의 TMEM은 Tensor Core에 직접 연결된 고속 메모리다. tmem_store는 레지스터의 텐서를 이 TMEM에 쓰는 연산인데, 입력 텐서의 레이아웃이 TMEM과 호환되면 별도의 변환 없이 바로 저장할 수 있다.
핵심 코드 분석
Canonicalization 패턴 추가
기존 Triton에는 reshape(cvt) -> reshape, local_store(cvt) -> local_store 등의 패턴이 있었지만, tmem_store(cvt) 패턴은 없었다.
After (Ops.cpp):
// tmem_store(cvt) -> tmem_store
struct CanonicalizeConvertFromTMEMStore
: public mlir::OpRewritePattern<nvidia_gpu::TMEMStoreOp> {
using OpRewritePattern::OpRewritePattern;
mlir::LogicalResult
matchAndRewrite(nvidia_gpu::TMEMStoreOp op,
PatternRewriter &rewriter) const override {
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
// bail for incompatible layouts
auto cvtSrcType = convert.getSrc().getType();
if (!nvidia_gpu::isDistributedLayoutTMemCompatible(
op.getOperation(), cvtSrcType, op.getDst().getType())) {
return failure();
}
rewriter.modifyOpInPlace(
op, [&]() { op.getSrcMutable().assign(convert.getSrc()); });
return mlir::success();
}
};
이 패턴의 동작 원리는 다음과 같다:
tmem_store의 입력이convert_layout연산의 결과인지 확인한다convert_layout이전의 원본 레이아웃이 TMEM과 호환되는지isDistributedLayoutTMemCompatible로 검증한다- 호환되면
convert_layout을 건너뛰고 원본 텐서를 직접tmem_store에 연결한다
패턴 등록
void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
// ... 기존 패턴들 ...
patterns.add<CanonicalizeConvertFromSplit>(context);
patterns.add<CanonicalizeConvertFromTMEMStore>(context); // 신규
}
MLIR 테스트
#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1],
warpsPerCTA = [4, 2], order = [0, 1]}>
#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]],
lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]],
warp = [[32, 0], [64, 0], [16, 0]], block = []}>
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, colStride = 1>
// CHECK-NOT: ttg.convert_layout
%1 = ttg.convert_layout %arg0 : tensor<128x64xbf16, #linear> -> tensor<128x64xbf16, #blocked>
// CHECK: ttng.tmem_store %{{.*}} : tensor<128x64xbf16, #linear> ->
ttng.tmem_store %1, %arg1, %true : tensor<128x64xbf16, #blocked> -> ...
테스트는 #linear 레이아웃에서 #blocked로의 convert_layout이 제거되고, tmem_store가 #linear 텐서를 직접 사용하는 것을 검증한다.
왜 이게 좋은가
IR 변환 전:
%0 = ... : tensor<128x64xbf16, #linear>
%1 = ttg.convert_layout %0 : #linear -> #blocked ← 불필요한 데이터 이동
ttng.tmem_store %1, %mem : #blocked -> tmem
IR 변환 후:
%0 = ... : tensor<128x64xbf16, #linear>
ttng.tmem_store %0, %mem : #linear -> tmem ← 직접 저장
convert_layout은 레지스터 간 데이터 셔플링을 수반하므로, 이를 제거하면 FlexAttention 커널에서 불필요한 warp-level 통신이 사라진다. FlexAttention은 attention 연산의 핵심 경로에서 TMEM Store를 빈번하게 사용하기 때문에, 이 최적화의 효과가 누적되어 체감 성능 차이가 발생한다.
정리
컴파일러 최적화의 정석을 보여주는 PR이다. (1) 불필요한 연산 패턴을 식별하고, (2) 호환성 검증 후 안전하게 제거하며, (3) MLIR FileCheck 테스트로 검증한다. 기존 local_store, reshape 등의 fold 패턴과 동일한 구조를 따르면서 TMEM Store에 확장 적용한 점이 깔끔하다. 하드웨어 특화 메모리(TMEM)와 컴파일러 IR의 상호작용을 이해하는 데 좋은 사례다.
참고 자료
- Triton MLIR Dialects — Triton의 MLIR dialect 구조
- NVIDIA Blackwell Architecture — Tensor Memory 소개
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] tcgen05.cp를 Generic Matrix Descriptor Lowering으로 통합
- 현재글 : [Triton] TMEM Store 레이아웃 변환 최적화 — FlexAttention 성능 복구
- 다음글 [Triton] debuginfo 테스트 단순화 — subprocess 제거
댓글