본문으로 건너뛰기

[Triton] HIPBackend에서 import torch 가드 추가 — JAX 호환성 복원

PR 링크: triton-lang/triton#9441 상태: Merged | 변경: +15 / -2

들어가며

Triton은 PyTorch뿐 아니라 JAX 등 다른 프레임워크에서도 사용된다. 그런데 AMD HIPBackend의 is_within_2gb() 메서드에 import torch가 무조건 실행되도록 들어가 있어, torch가 설치되지 않은 환경(예: jax-triton)에서 ImportError가 발생했다.

핵심 코드 분석

Before

@staticmethod
def is_within_2gb(arg):
    import torch  # 매번 무조건 import

    MAX_INT_32 = 2**31 - 1
    if hasattr(arg, "ptr_range"):
        return arg.ptr_range() <= MAX_INT_32
    if isinstance(arg, torch.Tensor) and hasattr(arg, "untyped_storage"):
        return arg.untyped_storage().size() <= MAX_INT_32
    return False

After

_torch_available: None | bool = None

@staticmethod
def is_within_2gb(arg):
    if HIPBackend._torch_available is None:
        try:
            import torch
            HIPBackend._torch_available = True
        except ImportError:
            HIPBackend._torch_available = False
    elif HIPBackend._torch_available:
        import torch

    MAX_INT_32 = 2**31 - 1
    if hasattr(arg, "ptr_range"):
        return arg.ptr_range() <= MAX_INT_32
    if HIPBackend._torch_available and isinstance(arg, torch.Tensor) \
            and hasattr(arg, "untyped_storage"):
        return arg.untyped_storage().size() <= MAX_INT_32
    return False

왜 이게 좋은가

  • 호환성: torch가 없는 환경에서도 HIPBackend가 정상 동작한다.
  • 성능: 클래스 변수로 torch 가용성을 캐싱하여, 첫 호출 이후에는 try/except 오버헤드가 없다.
  • 안전성: isinstance(arg, torch.Tensor) 검사 전에 _torch_available 플래그를 확인하여, torch가 없을 때 NameError를 방지한다.

정리

단순한 import torch 하나가 프레임워크 호환성을 깨뜨린 사례다. lazy import + 캐싱 패턴으로 해결하여, Triton의 프레임워크 독립성을 복원했다.

참고 자료


이 글은 AI 도구의 도움을 받아 작성되었습니다.

댓글

관련 포스트

PR Analysis 의 다른글