본문으로 건너뛰기

[Triton] TensorDescriptor를 tuple 인자로 전달 가능하게 지원

들어가며

Triton의 TensorDescriptor는 TMA(Tensor Memory Accelerator)를 사용하기 위한 핵심 객체다. 기존에는 커널의 최상위 인자로만 전달 가능했지만, tuple/구조체에 묶어서 전달하는 것은 불가능했다. 이 PR은 커널 인자가 중첩 tuple을 포함할 때도 TensorDescriptor를 올바르게 처리하도록 NVIDIA와 AMD 코드를 통합한다.

핵심 코드 분석

핵심 C++ 로직 (specialize.cc)

bool visit_make_tensordesc_args(PyObject *arg, PyObject *sig,
                                PyObject *relevant_paths, ...) {
    for (Py_ssize_t i = 0; i < len; ++i) {
        PyObject *s = PyTuple_GET_ITEM(sig, i);
        if (PyUnicode_CheckExact(s)) {
            std::string_view type_str(type_cstr, size);
            if (type_str.substr(0, tensordesc.length()) != tensordesc) {
                // 일반 인자: 그대로 추가
                PyList_Append(result, a);
                continue;
            }
            // TensorDescriptor 인자: make_tensordesc_arg 호출로 변환
            auto desc_args = PyObject_Vectorcall(make_tensordesc_arg, ...);
            PyList_SetSlice(result, PY_SSIZE_T_MAX, PY_SSIZE_T_MAX, desc_args);
        } else {
            // 중첩 tuple: 재귀적으로 처리
            visit_make_tensordesc_args(a, s, inner_relevant_paths, ...);
        }
    }
}

중첩 tuple에서 TensorDescriptor를 재귀적으로 탐색하고, relevant_paths dict로 TensorDescriptor가 포함된 경로만 선택적으로 처리한다.

테스트 코드

def kernel(out_ptr, payload, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    desc0 = payload[0]  # tuple의 첫 번째: TensorDescriptor
    desc1 = payload[1]  # tuple의 두 번째: TensorDescriptor
    m_idx = payload[2]  # tuple의 세 번째: 일반 정수

    block0 = desc0.load([m_idx * M_BLOCK, N_BLOCK])
    block1 = desc1.load([m_idx * M_BLOCK, 2 * N_BLOCK])

TensorDescriptor 두 개와 정수 하나를 하나의 tuple로 묶어 커널에 전달하는 패턴이 작동한다.

왜 이게 좋은가

  • API 유연성: 복잡한 커널에서 관련 descriptor들을 구조체로 묶어 전달할 수 있다.
  • NVIDIA/AMD 통합: 거의 동일한 처리 로직을 공유하도록 통합하여 유지보수 부담을 줄였다.
  • C++ 성능: Python의 느린 재귀 대신 C++(specialize.cc)에서 처리하여 커널 호출 오버헤드를 최소화한다.

정리

+362/-179 변경으로, TensorDescriptor의 사용성을 크게 개선했다. 특히 C++ 수준에서의 recursive tuple traversal이 깔끔하게 구현되어 있다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글