[Triton] TMA im2col 모드 — LLVM Lowering 구현
PR 링크: triton-lang/triton#9322 상태: Merged | 변경: +316 / -48
들어가며
이 PR은 NVIDIA TMA im2col 모드 시리즈의 다섯 번째로, MLIR 수준의 im2col op을 실제 GPU에서 실행 가능한 LLVM IR로 변환하는 lowering을 구현한다. TMA im2col 모드는 convolution의 im2col 변환을 하드웨어가 직접 수행하므로, descriptor 생성과 복사 명령 모두 특수한 lowering이 필요하다.
핵심 코드 분석
im2col Tensor Descriptor 생성 lowering
TMA im2col descriptor는 tiled descriptor와 달리 convolution 관련 파라미터(base_offset, traversal_stride 등)를 추가로 필요로 한다.
// Before: tiled descriptor만 지원
void lowerMakeTensorDescOp(MakeTensorDescOp op) {
// cuTensorMapEncodeTiled(...)
}
// After: im2col descriptor 분기 추가
void lowerMakeTensorDescOp(MakeTensorDescOp op) {
if (isa<TensorDescIm2ColType>(op.getResult().getType())) {
// cuTensorMapEncodeIm2Col(
// tensorMap, tensorDataType,
// tensorRank, globalAddress, globalDim, globalStrides,
// pixelBoxLowerCorner, pixelBoxUpperCorner,
// channelsPerPixel, pixelsPerColumn,
// elementStrides, interleave, swizzle, l2Promotion, oobFill)
} else {
// cuTensorMapEncodeTiled(...)
}
}
TMA 복사에 offset 전달
im2col 모드의 TMA 복사는 im2col_offsets를 PTX 명령어에 전달해야 한다:
// Before
// cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier
// [smem_addr], [tensorMap, {coords}], [mbar]
// After — im2col 모드
// cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier
// [smem_addr], [tensorMap, {coords}], [mbar], {offsets}
void lowerAsyncTMACopyGlobalToLocalOp(
AsyncTMACopyGlobalToLocalOp op) {
auto offsets = op.getIm2colOffsets();
if (!offsets.empty()) {
// im2col 모드: offset을 PTX inline assembly에 포함
createInlineAsm(
"cp.async.bulk.tensor.${rank}d"
".shared::cluster.global.im2col.mbarrier"
"::complete_tx::bytes"
" [$0], [$1, {$coords}], [$2], {$offsets};",
...);
}
}
Verifier 강화
LogicalResult MakeTensorDescOp::verify() {
if (isa<TensorDescIm2ColType>(getResult().getType())) {
// im2col: input rank >= 3 (batch, spatial dims, channels)
if (getTensorRank() < 3)
return emitError("im2col requires at least 3D input");
// block_shape rank == 2 (pixels_per_column, channels_per_pixel)
if (getBlockShape().size() != 2)
return emitError("im2col block must be 2D");
}
return success();
}
왜 이게 좋은가
- 하드웨어 가속 im2col: CUDA
cuTensorMapEncodeIm2ColAPI를 통해 TMA 하드웨어가 직접 im2col 변환을 수행하여, 소프트웨어 변환 대비 메모리 대역폭을 절약한다. - 기존 코드 재활용: tiled lowering의 인프라(descriptor 생성, 복사 명령 생성)를 공유하면서 im2col 전용 분기만 추가하는 효율적 구조다.
- 정합성 보장: verifier를 통해 im2col/tiled 모드의 파라미터 제약을 컴파일 타임에 검사한다.
정리
이 PR은 TMA im2col 모드의 LLVM lowering을 구현한다. cuTensorMapEncodeIm2Col API를 통한 descriptor 생성과, im2col offset을 포함하는 비동기 복사 명령 생성을 포함하며, verifier로 모드별 파라미터 정합성을 보장한다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 핵심 코드와 explaination은 실제 PR diff를 기반으로 합니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] AMD GFX1250용 Warp-Pipeline f16 GEMM 예제 추가
- 현재글 : [Triton] TMA im2col 모드 — LLVM Lowering 구현
- 다음글 [triton] Triton 컴파일러 최적화: In-thread 트리 리덕션 도입
댓글