본문으로 건너뛰기

[Triton] AMD Gluon DSL에 TDM L2 Prefetch 노출 — 사용자 수준 프리페치 제어

PR 링크: triton-lang/triton#9148 상태: Merged | 변경: +278 / -1

들어가며

AMD GPU의 TDM(Tensor Data Movement) 엔진은 L2 캐시로 데이터를 미리 가져오는 prefetch 기능을 지원한다. 이전 PR(#9086)에서 하위 레벨의 amdgpu op과 LLVM lowering이 구현되었고, 이 PR은 그 위에 Gluon DSL API를 추가하여 사용자가 Python 수준에서 prefetch를 직접 제어할 수 있게 한다.

L2 prefetch는 side effect가 없고(잘못된 주소에 대해 segfault만 발생), speculative/non-speculative 두 모드를 지원한다.

핵심 코드 분석

Python API: Gluon DSL에서 prefetch 호출

# Gluon Python API
def prefetch(desc: TensorDescriptor, coord: list, *, speculative: bool = True):
    """L2 prefetch를 시작한다.
    speculative=True이면 잘못된 주소도 무시하고 진행한다.
    """
    ...

C++ 바인딩: GluonOpBuilder에 prefetch 생성 추가

// Before: prefetch 관련 API 없음

// After: gluon_ir.cc
.def("create_tdm_prefetch",
     [](GluonOpBuilder &self, Value descPtr,
        std::vector<Value> &coord, Value pred,
        bool speculative) {
       self.create<ttag::TDMPrefetchOp>(
           descPtr, coord, pred, speculative);
     })

MLIR Op 정의: TDMPrefetchOp

Gluon 레벨에서 생성된 prefetch op은 TritonAMDGPU dialect의 TDMPrefetchOp으로 lowering된다.

// TritonAMDGPUOps.td에 추가된 op 정의
def TTAG_TDMPrefetchOp : TTAG_Op<"tdm_prefetch"> {
  let arguments = (ins
    TT_TensorDescType:$desc,
    Variadic<I32>:$coords,
    I1:$pred,
    BoolAttr:$speculative
  );
}

E2E 테스트: 실제 커널에서 prefetch 사용

@gluon.jit
def prefetch_kernel(desc):
    amd.prefetch(desc, [0, 0], speculative=True)

def test_tdm_prefetch():
    inp = torch.randn(128, 128, device="cuda", dtype=torch.float16)
    desc = gluon.amd.gfx1250.TensorDescriptor(
        base=inp, shape=[128, 128],
        block_shape=[64, 64], ...)
    prefetch_kernel[(1,)](desc)

왜 이게 좋은가

  1. 사용자 수준 성능 제어: 메모리 접근 패턴을 아는 사용자가 적절한 시점에 prefetch를 삽입하여 L2 cache hit rate를 높일 수 있다.
  2. 안전한 speculative 모드: speculative prefetch는 잘못된 주소에 대해서도 안전하게 무시되므로, 조건 분기 없이 공격적으로 사용할 수 있다.
  3. 계층적 추상화: 하드웨어 intrinsic → MLIR op → Gluon API의 3단계 추상화를 통해, 저수준 세부사항을 감추면서도 제어력을 제공한다.

정리

이 PR은 AMD TDM L2 prefetch를 Gluon DSL에서 호출할 수 있는 Python API를 추가한다. Tensor descriptor 기반의 prefetch 연산을 speculative/non-speculative 모드로 지원하며, MLIR op 정의부터 E2E 테스트까지 포함하는 완결된 구현이다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 핵심 코드와 explaination은 실제 PR diff를 기반으로 합니다.

댓글

관련 포스트

PR Analysis 의 다른글