[Triton] TMA im2col 모드 — Gluon API 구현
PR 링크: triton-lang/triton#9391 상태: Merged | 변경: +482 / -83
들어가며
이 PR은 NVIDIA TMA im2col 모드 시리즈의 여섯 번째로, Gluon DSL에서 im2col 모드 TMA 복사를 Python으로 직접 호출할 수 있는 API를 구현한다. im2col은 convolution의 입력 텐서를 행렬 곱셈에 적합한 형태로 변환하는 기법이며, TMA 하드웨어가 이를 자동으로 수행하면 소프트웨어 변환 오버헤드가 사라진다.
핵심 코드 분석
새로운 TensorDescriptorIm2Col 클래스
# Python API — im2col 전용 tensor descriptor
in_desc = gluon.nvidia.hopper.TensorDescriptorIm2Col(
base=inp,
shape=list(inp.shape), # 전체 텐서 shape
block_shape=[pixels, channels], # im2col 출력 블록
im2col_offsets=[0, 0], # 시작 오프셋
layout=layout,
)
C++ 바인딩: im2col 타입 생성
// Before: tiled 모드만 지원
.def("get_tensor_descriptor_type", ...)
// After: im2col 전용 타입 생성 함수 추가
.def("get_tensor_descriptor_im2col_layout_type",
[](GluonOpBuilder &self, Type blockType, bool isSigned,
Attribute layout) -> Type {
auto ctx = self.getContext();
auto blockTy = cast<RankedTensorType>(blockType);
auto blockTyLayout = blockTy.cloneWithEncoding(layout);
return triton::nvidia_gpu::TensorDescIm2ColType::get(
ctx, blockTyLayout);
})
Specialization: im2col 모드 자동 감지
// specialize.cc — 커널 특수화 시 im2col 자동 감지
bool is_im2col = false;
if (has_layout && nvidia_tensor_descriptor_im2col_cls) {
int is_inst = PyObject_IsInstance(
arg, nvidia_tensor_descriptor_im2col_cls);
is_im2col = is_inst == 1;
}
// cache key에 im2col 여부를 인코딩
desc_cstr = is_im2col ? "tensordesc_im2col<" : "tensordesc<";
im2col 모드에서는 input tensor의 rank 정보도 cache key에 포함된다:
if (is_im2col) {
auto tensor_shape_obj = from_new_ref(
PyObject_GetAttr(arg, shape_attr));
Py_ssize_t tensor_rank = PySequence_Size(
tensor_shape_obj.ptr());
desc_cstr += ",input_rank=";
desc_cstr += std::to_string(tensor_rank);
}
TMA 복사 API에 offset 전달
// Before
self.create<ttng::AsyncTMACopyGlobalToLocalOp>(
descPtr, coord, barrier, result, pred, multicast);
// After — im2col offset을 선택적으로 전달
self.create<ttng::AsyncTMACopyGlobalToLocalOp>(
descPtr, coord, offsetsRange, barrier, result, pred, multicast);
E2E 테스트
@gluon.jit
def tma_im2col_kernel(in_desc, out_desc):
smem = ttgl.allocate_shared_memory(
in_desc.dtype, in_desc.block_shape, in_desc.layout)
bar = mbarrier.allocate_mbarrier()
mbarrier.init(bar, count=1)
mbarrier.expect(bar, in_desc.block_type.nbytes)
tma.async_copy_global_to_shared_im2col(
in_desc, [0, 0, 0, 0], [0, 0], bar, smem)
mbarrier.wait(bar, phase=0)
tma.async_copy_shared_to_global(out_desc, [0, 0], smem)
tma.store_wait(pendings=0)
왜 이게 좋은가
- 하드웨어 가속 im2col: 소프트웨어 im2col 변환 없이 TMA가 직접 패치 추출을 수행하여 convolution 커널의 메모리 대역폭을 절약한다.
- Python 수준 접근: Gluon DSL을 통해 복잡한 TMA im2col 설정을 Python에서 간결하게 표현할 수 있다.
- JIT 캐시 호환:
tensordesc_im2col<...>형식의 cache key로 tiled/im2col 모드를 구분하여 올바른 커널 캐싱을 보장한다.
정리
이 PR은 TMA im2col 시리즈의 사용자 대면 API를 완성한다. TensorDescriptorIm2Col Python 클래스, C++ 바인딩, JIT specialization, E2E 테스트를 포함하는 완결된 구현으로, Gluon에서 convolution 최적화에 TMA 하드웨어를 직접 활용할 수 있게 한다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 핵심 코드와 explaination은 실제 PR diff를 기반으로 합니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [ACE-Step-1.5] Apple Silicon 맥북에서 MLX 네이티브 백엔드로 5Hz LM 추론 속도 혁신
- 현재글 : [Triton] TMA im2col 모드 — Gluon API 구현
- 다음글 [triton] Generic Multi-CTA convert_layout 지원
댓글