[triton] PyTorch 없이 Triton CUDA 백엔드 독립 사용 지원
PR 링크: triton-lang/triton#9578 상태: Merged | 변경: +416 / -21
들어가며
Triton은 GPU 커널 작성 프레임워크이지만, 실행 시 PyTorch에 의존하는 부분이 있었습니다. 특히 텐서 타입 검사, 디바이스 관리 등에서 torch.Tensor를 import하는 코드가 존재했습니다. 이 PR은 PyTorch가 설치되지 않은 환경에서도 Triton CUDA 백엔드가 정상 동작하도록 의존성을 분리합니다.
핵심 코드 분석
1. torch import를 지연(lazy) 방식으로 전환
Before:
try {
torch_tensor_cls = import_from("torch", "Tensor");
} catch (py::error_already_set &e) {
}
After:
PyObject *loaded_modules = PyImport_GetModuleDict();
PyObject *torch_module = PyDict_GetItemString(loaded_modules, "torch");
if (torch_module) {
auto tensor_cls =
from_new_ref(PyObject_GetAttrString(torch_module, "Tensor"));
if (!tensor_cls)
return false;
torch_tensor_cls = tensor_cls.release().ptr();
}
기존 코드는 import torch를 시도하고 실패하면 무시하는 방식이었지만, 이 자체가 torch가 없는 환경에서 불필요한 에러를 발생시켰습니다. 새 코드는 이미 로드된 모듈 딕셔너리를 확인하여, torch가 로드되어 있을 때만 Tensor 클래스를 참조합니다.
2. CUDA 컨텍스트 자동 초기화
static void ensureCudaContext() {
CUcontext pctx;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
if (!pctx) {
CUdevice device;
CUDA_CHECK(cuDeviceGet(&device, 0));
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
}
}
PyTorch 없이는 CUDA 컨텍스트가 자동으로 초기화되지 않으므로, 드라이버 API를 직접 호출하여 컨텍스트를 생성합니다. 이로써 torch.cuda에 의존하지 않고도 GPU 디바이스를 사용할 수 있습니다.
3. torch 없는 환경 테스트
테스트 코드는 __import__를 오버라이드하여 torch import를 완전히 차단한 상태에서 커널 컴파일과 실행을 검증합니다.
def _guarded_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == "torch" or name.startswith("torch."):
torch_import_attempts.append(name)
raise ImportError("torch import is forbidden in this test")
return _orig_import(name, globals, locals, fromlist, level)
왜 이게 좋은가
이 변경은 Triton의 적용 범위를 크게 확장합니다. 임베디드 시스템, 경량 컨테이너, 또는 PyTorch 없이 순수 CUDA 작업을 수행하는 환경에서도 Triton을 사용할 수 있게 됩니다. 또한 컴파일러 파이프라인에서 torch 의존성이 제거되므로 빌드 시간과 바이너리 크기도 줄어듭니다.
정리
torch.Tensorimport를 모듈 딕셔너리 확인 방식으로 변경 (lazy)- CUDA 컨텍스트 자동 초기화로
torch.cuda의존성 제거 - GPU 메모리 할당/해제/복사 유틸리티 함수 추가
- torch import 차단 테스트로 독립 동작 검증
참고 자료
이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] Multi-CTA 예제에서 Program ID를 Shared Memory에 저장하여 재계산 방지
- 현재글 : [triton] PyTorch 없이 Triton CUDA 백엔드 독립 사용 지원
- 다음글 [Gradio] MCP 도구 호출 레이턴시 개선 — HTTP 루프백 제거
댓글