[Triton] AsyncCompileMode 에러 발생 시 active_mode 초기화 보장
PR 링크: triton-lang/triton#9491 상태: Merged | 변경: +22 / -1
들어가며
Triton의 AsyncCompileMode는 커널 컴파일을 비동기로 처리하는 context manager다. 하지만 ignore_errors=False 상태에서 컴파일 에러가 발생하면, active_mode가 None으로 리셋되지 않아 후속 비동기 컴파일이 불가능해지는 버그가 있었다. 이 PR은 에러 발생 여부와 관계없이 active_mode를 먼저 None으로 설정한다.
핵심 코드 분석
Before: active_mode 리셋이 마지막에 위치
def __exit__(self, exc_type, exc_value, traceback):
# Finalize any outstanding compiles
for future in as_completed(self.raw_futures):
self.future_kernels[future._key].result(self.ignore_errors)
active_mode.set(None)
as_completed 루프 내에서 result() 호출 시 예외가 발생하면, active_mode.set(None)에 도달하지 못한다. 이 상태에서 사용자가 예외를 catch하고 새 AsyncCompileMode를 사용하려 하면, 이미 tear down된 executor를 참조하게 된다.
After: active_mode 리셋을 맨 앞으로 이동
def __exit__(self, exc_type, exc_value, traceback):
active_mode.set(None)
# Finalize any outstanding compiles
for future in as_completed(self.raw_futures):
self.future_kernels[future._key].result(self.ignore_errors)
단 한 줄의 이동이지만, 에러 경로에서의 상태 일관성을 보장한다.
테스트 추가
def test_async_compile_error(fresh_triton_cache):
@triton.jit
def fn(x: tl.constexpr):
tl.static_assert(x == 2)
with pytest.raises(triton.compiler.errors.CompileTimeAssertionFailure):
with (
ThreadPoolExecutor(2) as pool,
triton.AsyncCompileMode(pool),
):
assert triton.runtime._async_compile.active_mode.get() is not None
fn.warmup(1, grid=(1, ))
# 에러 후에도 active_mode가 None으로 정리되었는지 확인
assert triton.runtime._async_compile.active_mode.get() is None
왜 이게 좋은가
- 상태 일관성: 예외 발생 여부와 관계없이
active_mode가 항상 정리된다. - 재사용 안전: 에러 후 새로운
AsyncCompileMode진입이 안전하게 가능해진다. - 최소한의 변경: 한 줄 이동으로 문제를 해결한다.
- 테스트로 검증: 에러 시나리오를 재현하는 테스트가 포함되어 regression을 방지한다.
정리
이 PR은 AsyncCompileMode.__exit__에서 active_mode.set(None) 호출 위치를 에러 핸들링보다 앞으로 이동하여, 컴파일 에러가 전파되더라도 상태가 올바르게 초기화되도록 수정했다. 간단하지만 비동기 컴파일의 안정성에 중요한 수정이다.
참고 자료
이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.
관련 포스트
- [Triton] 모듈 언로드 테스트 비결정적 실패 수정 — GC 비활성화로 안정성 확보
- [Triton] HIPBackend에서 import torch 가드 추가 — JAX 호환성 복원
- [triton] AMD Canonicalize Pointers에서 arith.select의 비대칭 fat pointer 처리 강화
- [CPython] 64-bit ARM 커널에서 32-bit ARM Android의 sysconfig ABI 감지 수정
- [Axolotl] 플러그인에 scored rollout 디스패치, 외부 플러그인 경로 확장, vLLM 에러 처리 개선
PR Analysis 의 다른글
- 이전글 [Grafana Loki] 오브젝트 스토어 클라이언트에 요청 레이턴시 히스토그램 메트릭 추가
- 현재글 : [Triton] AsyncCompileMode 에러 발생 시 active_mode 초기화 보장
- 다음글 [Ray] 다중 입력 연산자의 메모리 귀속 오류 수정으로 데드락 해결
댓글