본문으로 건너뛰기

[Triton] TMA im2col 모드 — Gluon API 구현

PR 링크: triton-lang/triton#9391 상태: Merged | 변경: +482 / -83

들어가며

이 PR은 NVIDIA TMA im2col 모드 시리즈의 여섯 번째로, Gluon DSL에서 im2col 모드 TMA 복사를 Python으로 직접 호출할 수 있는 API를 구현한다. im2col은 convolution의 입력 텐서를 행렬 곱셈에 적합한 형태로 변환하는 기법이며, TMA 하드웨어가 이를 자동으로 수행하면 소프트웨어 변환 오버헤드가 사라진다.

핵심 코드 분석

새로운 TensorDescriptorIm2Col 클래스

# Python API — im2col 전용 tensor descriptor
in_desc = gluon.nvidia.hopper.TensorDescriptorIm2Col(
    base=inp,
    shape=list(inp.shape),        # 전체 텐서 shape
    block_shape=[pixels, channels], # im2col 출력 블록
    im2col_offsets=[0, 0],         # 시작 오프셋
    layout=layout,
)

C++ 바인딩: im2col 타입 생성

// Before: tiled 모드만 지원
.def("get_tensor_descriptor_type", ...)

// After: im2col 전용 타입 생성 함수 추가
.def("get_tensor_descriptor_im2col_layout_type",
     [](GluonOpBuilder &self, Type blockType, bool isSigned,
        Attribute layout) -> Type {
       auto ctx = self.getContext();
       auto blockTy = cast<RankedTensorType>(blockType);
       auto blockTyLayout = blockTy.cloneWithEncoding(layout);
       return triton::nvidia_gpu::TensorDescIm2ColType::get(
           ctx, blockTyLayout);
     })

Specialization: im2col 모드 자동 감지

// specialize.cc — 커널 특수화 시 im2col 자동 감지
bool is_im2col = false;
if (has_layout && nvidia_tensor_descriptor_im2col_cls) {
    int is_inst = PyObject_IsInstance(
        arg, nvidia_tensor_descriptor_im2col_cls);
    is_im2col = is_inst == 1;
}
// cache key에 im2col 여부를 인코딩
desc_cstr = is_im2col ? "tensordesc_im2col<" : "tensordesc<";

im2col 모드에서는 input tensor의 rank 정보도 cache key에 포함된다:

if (is_im2col) {
    auto tensor_shape_obj = from_new_ref(
        PyObject_GetAttr(arg, shape_attr));
    Py_ssize_t tensor_rank = PySequence_Size(
        tensor_shape_obj.ptr());
    desc_cstr += ",input_rank=";
    desc_cstr += std::to_string(tensor_rank);
}

TMA 복사 API에 offset 전달

// Before
self.create<ttng::AsyncTMACopyGlobalToLocalOp>(
    descPtr, coord, barrier, result, pred, multicast);

// After — im2col offset을 선택적으로 전달
self.create<ttng::AsyncTMACopyGlobalToLocalOp>(
    descPtr, coord, offsetsRange, barrier, result, pred, multicast);

E2E 테스트

@gluon.jit
def tma_im2col_kernel(in_desc, out_desc):
    smem = ttgl.allocate_shared_memory(
        in_desc.dtype, in_desc.block_shape, in_desc.layout)
    bar = mbarrier.allocate_mbarrier()
    mbarrier.init(bar, count=1)
    mbarrier.expect(bar, in_desc.block_type.nbytes)
    tma.async_copy_global_to_shared_im2col(
        in_desc, [0, 0, 0, 0], [0, 0], bar, smem)
    mbarrier.wait(bar, phase=0)
    tma.async_copy_shared_to_global(out_desc, [0, 0], smem)
    tma.store_wait(pendings=0)

왜 이게 좋은가

  1. 하드웨어 가속 im2col: 소프트웨어 im2col 변환 없이 TMA가 직접 패치 추출을 수행하여 convolution 커널의 메모리 대역폭을 절약한다.
  2. Python 수준 접근: Gluon DSL을 통해 복잡한 TMA im2col 설정을 Python에서 간결하게 표현할 수 있다.
  3. JIT 캐시 호환: tensordesc_im2col<...> 형식의 cache key로 tiled/im2col 모드를 구분하여 올바른 커널 캐싱을 보장한다.

정리

이 PR은 TMA im2col 시리즈의 사용자 대면 API를 완성한다. TensorDescriptorIm2Col Python 클래스, C++ 바인딩, JIT specialization, E2E 테스트를 포함하는 완결된 구현으로, Gluon에서 convolution 최적화에 TMA 하드웨어를 직접 활용할 수 있게 한다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 핵심 코드와 explaination은 실제 PR diff를 기반으로 합니다.

댓글

관련 포스트

PR Analysis 의 다른글