[Triton] TMA im2col 모드 — tma load op 수정
PR 링크: triton-lang/triton#9303 상태: Merged | 변경: +56 / -72
들어가며
TMA(Tensor Memory Accelerator)는 NVIDIA Hopper 이상 GPU에서 제공하는 하드웨어 가속 메모리 복사 엔진이다. 기존에는 tiled 모드만 지원했는데, 이 시리즈는 im2col(image to column) 모드를 추가한다. im2col은 convolution 연산에서 입력 텐서의 패치를 열(column)로 변환하여 행렬 곱셈으로 처리할 수 있게 하는 기법이다.
이 PR은 시리즈의 세 번째로, AsyncTMACopyGlobalToLocalOp에 im2col offset을 전달하는 인터페이스를 수정한다.
핵심 코드 분석
Before: offset 파라미터 없음
// AsyncTMACopyGlobalToLocalOp
def TTNG_AsyncTMACopyGlobalToLocalOp :
TTNG_Op<"async_tma_copy_global_to_local"> {
let arguments = (ins
TT_TensorDescType:$desc_ptr,
Variadic<I32>:$coords,
TTNG_MBarrierType:$barrier,
TTG_MemDescType:$result,
I1:$pred,
BoolAttr:$multicast
);
}
After: im2col offset 추가
def TTNG_AsyncTMACopyGlobalToLocalOp :
TTNG_Op<"async_tma_copy_global_to_local"> {
let arguments = (ins
TT_TensorDescOrIm2ColType:$desc_ptr, // im2col 타입도 허용
Variadic<I32>:$coords,
Variadic<I16>:$im2col_offsets, // im2col offset 추가
TTNG_MBarrierType:$barrier,
TTG_MemDescType:$result,
I1:$pred,
BoolAttr:$multicast
);
}
Verifier도 업데이트되어 im2col 모드와 tiled 모드의 유효성을 구분 검증한다:
LogicalResult AsyncTMACopyGlobalToLocalOp::verify() {
auto descType = getDescPtr().getType();
if (isa<TensorDescIm2ColType>(descType)) {
// im2col: offsets가 있어야 함
if (getIm2colOffsets().empty())
return emitError("im2col mode requires offsets");
} else {
// tiled: offsets가 없어야 함
if (!getIm2colOffsets().empty())
return emitError("tiled mode should not have offsets");
}
return success();
}
왜 이게 좋은가
- 하드웨어 기능 활용: TMA im2col 모드를 사용하면 소프트웨어 im2col 변환 없이 하드웨어가 직접 패치 추출을 수행하여 메모리 대역폭을 절약한다.
- 타입 안전성:
TensorDescOrIm2ColType유니온 타입과 verifier를 통해 im2col/tiled 모드의 파라미터 사용 오류를 컴파일 타임에 잡는다. - 깔끔한 인터페이스: offset이 선택적(Variadic) 파라미터로 추가되어 기존 tiled 모드 사용자에게 영향이 없다.
정리
이 PR은 TMA load op에 im2col offset 파라미터를 추가하고, descriptor 타입을 im2col과 tiled 모두 허용하도록 확장했다. Verifier를 통해 모드별 파라미터 정합성을 보장한다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 핵심 코드와 explaination은 실제 PR diff를 기반으로 합니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] NVIDIA TMA im2col 모드 Tensor Descriptor 지원
- 현재글 : [Triton] TMA im2col 모드 — tma load op 수정
- 다음글 [Ray Data] 논리적 최적화 규칙에서 in-place 변형을 제거하여 불변성 준비
댓글