[Triton] TensorDescriptor를 tuple 인자로 전달 가능하게 지원
들어가며
Triton의 TensorDescriptor는 TMA(Tensor Memory Accelerator)를 사용하기 위한 핵심 객체다. 기존에는 커널의 최상위 인자로만 전달 가능했지만, tuple/구조체에 묶어서 전달하는 것은 불가능했다. 이 PR은 커널 인자가 중첩 tuple을 포함할 때도 TensorDescriptor를 올바르게 처리하도록 NVIDIA와 AMD 코드를 통합한다.
핵심 코드 분석
핵심 C++ 로직 (specialize.cc)
bool visit_make_tensordesc_args(PyObject *arg, PyObject *sig,
PyObject *relevant_paths, ...) {
for (Py_ssize_t i = 0; i < len; ++i) {
PyObject *s = PyTuple_GET_ITEM(sig, i);
if (PyUnicode_CheckExact(s)) {
std::string_view type_str(type_cstr, size);
if (type_str.substr(0, tensordesc.length()) != tensordesc) {
// 일반 인자: 그대로 추가
PyList_Append(result, a);
continue;
}
// TensorDescriptor 인자: make_tensordesc_arg 호출로 변환
auto desc_args = PyObject_Vectorcall(make_tensordesc_arg, ...);
PyList_SetSlice(result, PY_SSIZE_T_MAX, PY_SSIZE_T_MAX, desc_args);
} else {
// 중첩 tuple: 재귀적으로 처리
visit_make_tensordesc_args(a, s, inner_relevant_paths, ...);
}
}
}
중첩 tuple에서 TensorDescriptor를 재귀적으로 탐색하고, relevant_paths dict로 TensorDescriptor가 포함된 경로만 선택적으로 처리한다.
테스트 코드
def kernel(out_ptr, payload, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
desc0 = payload[0] # tuple의 첫 번째: TensorDescriptor
desc1 = payload[1] # tuple의 두 번째: TensorDescriptor
m_idx = payload[2] # tuple의 세 번째: 일반 정수
block0 = desc0.load([m_idx * M_BLOCK, N_BLOCK])
block1 = desc1.load([m_idx * M_BLOCK, 2 * N_BLOCK])
TensorDescriptor 두 개와 정수 하나를 하나의 tuple로 묶어 커널에 전달하는 패턴이 작동한다.
왜 이게 좋은가
- API 유연성: 복잡한 커널에서 관련 descriptor들을 구조체로 묶어 전달할 수 있다.
- NVIDIA/AMD 통합: 거의 동일한 처리 로직을 공유하도록 통합하여 유지보수 부담을 줄였다.
- C++ 성능: Python의 느린 재귀 대신 C++(
specialize.cc)에서 처리하여 커널 호출 오버헤드를 최소화한다.
정리
+362/-179 변경으로, TensorDescriptor의 사용성을 크게 개선했다. 특히 C++ 수준에서의 recursive tuple traversal이 깔끔하게 구현되어 있다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글