본문으로 건너뛰기

[Triton] Aggregate 멤버를 cache key에 포함시키기

PR 링크: triton-lang/triton#8528 상태: Merged | 변경: +38 / -12

들어가며

Triton의 JIT 컴파일러는 커널을 캐싱하여 재컴파일을 방지한다. Cache key는 커널의 인자들을 기반으로 생성되는데, aggregate 타입(사용자 정의 구조체)이 인자로 전달되면 해당 멤버들의 변경이 cache key에 반영되어야 한다. 이 PR은 aggregate의 hash_attrs를 순회하여 모든 멤버 함수의 변경을 추적한다.

핵심 코드 분석

Cache key에 aggregate 추적 추가 (jit.py)

After:

def record_reference(self, val, var_dict=None, name=None):
    if val is None or type(val) is ModuleType:
        return

    if getattr(val, "__triton_aggregate__", False):
        for attr in val.hash_attrs:
            self.record_reference(attr)
        return

__triton_aggregate__ 플래그를 감지하면, hash_attrs에 포함된 모든 멤버(생성자 + 메서드)를 재귀적으로 추적한다.

Aggregate 정의 시 hash_attrs 수집 (core.py)

After:

hash_attrs = [cls.__init__]

for (name, member) in inspect.getmembers(cls):
    if inspect.isfunction(member) or isinstance(member, JITCallable):
        if name != "__init__":
            setattr(aggregate_value, name, member)
            hash_attrs.append(member)

aggregate_value.hash_attrs = hash_attrs

왜 이게 좋은가

  1. 캐시 정확성: aggregate 내부의 JIT 함수가 변경되면 캐시가 무효화된다
  2. 테스트 검증: _unsafe_update_src로 소스를 동적으로 변경하여 캐시 무효화 테스트 추가
  3. 하위 호환성: 기존의 non-aggregate 인자 처리는 그대로 유지

정리

캐시 키 생성은 컴파일러의 정확성에 직결되는 핵심 기능이다. Aggregate 타입이 도입되면서 단순 스칼라 값 비교만으로는 부족해졌고, 멤버 함수들의 소스 코드까지 추적해야 올바른 캐시 동작이 보장된다. 참고로 이 PR은 한 번 revert(#8567)된 후 reland(#8568)되었다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글

관련 포스트

PR Analysis 의 다른글