[Triton] HIPBackend에서 import torch 가드 추가 — JAX 호환성 복원
PR 링크: triton-lang/triton#9441 상태: Merged | 변경: +15 / -2
들어가며
Triton은 PyTorch뿐 아니라 JAX 등 다른 프레임워크에서도 사용된다. 그런데 AMD HIPBackend의 is_within_2gb() 메서드에 import torch가 무조건 실행되도록 들어가 있어, torch가 설치되지 않은 환경(예: jax-triton)에서 ImportError가 발생했다.
핵심 코드 분석
Before
@staticmethod
def is_within_2gb(arg):
import torch # 매번 무조건 import
MAX_INT_32 = 2**31 - 1
if hasattr(arg, "ptr_range"):
return arg.ptr_range() <= MAX_INT_32
if isinstance(arg, torch.Tensor) and hasattr(arg, "untyped_storage"):
return arg.untyped_storage().size() <= MAX_INT_32
return False
After
_torch_available: None | bool = None
@staticmethod
def is_within_2gb(arg):
if HIPBackend._torch_available is None:
try:
import torch
HIPBackend._torch_available = True
except ImportError:
HIPBackend._torch_available = False
elif HIPBackend._torch_available:
import torch
MAX_INT_32 = 2**31 - 1
if hasattr(arg, "ptr_range"):
return arg.ptr_range() <= MAX_INT_32
if HIPBackend._torch_available and isinstance(arg, torch.Tensor) \
and hasattr(arg, "untyped_storage"):
return arg.untyped_storage().size() <= MAX_INT_32
return False
왜 이게 좋은가
- 호환성: torch가 없는 환경에서도 HIPBackend가 정상 동작한다.
- 성능: 클래스 변수로 torch 가용성을 캐싱하여, 첫 호출 이후에는
try/except오버헤드가 없다. - 안전성:
isinstance(arg, torch.Tensor)검사 전에_torch_available플래그를 확인하여, torch가 없을 때NameError를 방지한다.
정리
단순한 import torch 하나가 프레임워크 호환성을 깨뜨린 사례다. lazy import + 캐싱 패턴으로 해결하여, Triton의 프레임워크 독립성을 복원했다.
참고 자료
이 글은 AI 도구의 도움을 받아 작성되었습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] NVIDIA TMA im2col 모드 Gluon 튜토리얼 - Convolution 커널 구현
- 현재글 : [Triton] HIPBackend에서 import torch 가드 추가 — JAX 호환성 복원
- 다음글 [triton] 컴파일된 커널 모듈 명시적 unload 지원
댓글