본문으로 건너뛰기

[triton] NVIDIA TMA im2col 모드 드라이버 지원

PR 링크: triton-lang/triton#9305 상태: Merged | 변경: +279 / -4

들어가며

TMA im2col 모드는 convolution의 입력 데이터를 im2col 변환하면서 동시에 메모리를 복사하는 하드웨어 가속 기능입니다. 이 PR은 CUDA 드라이버의 cuTensorMapEncodeIm2col API를 Python에서 호출할 수 있도록 바인딩을 추가합니다.

핵심 코드 분석

1. cuTensorMapEncodeIm2col 타입 정의

typedef CUresult (*cuTensorMapEncodeIm2col_t)(
    CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
    cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
    const cuuint64_t *globalStrides, const int *pixelBoxLowerCorner,
    const int *pixelBoxUpperCorner, cuuint32_t channelsPerPixel,
    cuuint32_t pixelsPerColumn, const cuuint32_t *elementStrides,
    CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
    CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);

2. fillTMADescriptorIm2col 구현

static PyObject *fillTMADescriptorIm2col(PyObject *self, PyObject *args) {
  // Parse: global_address, swizzle, elemSize, elemType,
  //        blockSize, shape, strides, padding,
  //        pixelBoxLower, pixelBoxUpper,
  //        channelsPerPixel, pixelsPerColumn, elementStrides
  if (!PyArg_ParseTuple(args, "KiiiOOOiOOiiO", ...)) return NULL;

  // shape으로 tensor rank 결정 (tiled 모드와 다르게 blockSize가 아닌 shape 기준)
  int rank = PySequence_Fast_GET_SIZE(shapeFast);

  // pixel box corners와 element strides 파싱
  for (int i = 0; i < spatialRank; ++i) {
    pixelBoxLowerInt[spatialRank - i - 1] = PyLong_AsLong(item);
    pixelBoxUpperInt[spatialRank - i - 1] = PyLong_AsLong(item);
  }
  // ...
}

기존 fillTMADescriptor(tiled 전용)를 fillTMADescriptorTiled로 이름을 변경하고, im2col 전용 fillTMADescriptorIm2col을 별도로 추가했습니다. im2col은 pixel box corner, channels per pixel, pixels per column 등 convolution 특화 파라미터가 필요합니다.

왜 이게 좋은가

  • 하드웨어 가속 활용: CPU에서 im2col 변환을 하는 대신 TMA 하드웨어가 메모리 복사와 동시에 변환을 수행하여 대역폭을 절약합니다.
  • 기존 API와 병행: tiled/im2col을 별도 함수로 분리하여 기존 tiled 코드에 영향 없습니다.

정리

TMA im2col 모드의 드라이버 레벨 지원을 추가한 PR입니다. cuTensorMapEncodeIm2col API 바인딩과 Python에서 im2col descriptor를 생성하는 fillTMADescriptorIm2col 함수가 핵심입니다.

참고 자료


이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글