[Triton] AMD Gluon DSL에 TDM L2 Prefetch 노출 — 사용자 수준 프리페치 제어
PR 링크: triton-lang/triton#9148 상태: Merged | 변경: +278 / -1
들어가며
AMD GPU의 TDM(Tensor Data Movement) 엔진은 L2 캐시로 데이터를 미리 가져오는 prefetch 기능을 지원한다. 이전 PR(#9086)에서 하위 레벨의 amdgpu op과 LLVM lowering이 구현되었고, 이 PR은 그 위에 Gluon DSL API를 추가하여 사용자가 Python 수준에서 prefetch를 직접 제어할 수 있게 한다.
L2 prefetch는 side effect가 없고(잘못된 주소에 대해 segfault만 발생), speculative/non-speculative 두 모드를 지원한다.
핵심 코드 분석
Python API: Gluon DSL에서 prefetch 호출
# Gluon Python API
def prefetch(desc: TensorDescriptor, coord: list, *, speculative: bool = True):
"""L2 prefetch를 시작한다.
speculative=True이면 잘못된 주소도 무시하고 진행한다.
"""
...
C++ 바인딩: GluonOpBuilder에 prefetch 생성 추가
// Before: prefetch 관련 API 없음
// After: gluon_ir.cc
.def("create_tdm_prefetch",
[](GluonOpBuilder &self, Value descPtr,
std::vector<Value> &coord, Value pred,
bool speculative) {
self.create<ttag::TDMPrefetchOp>(
descPtr, coord, pred, speculative);
})
MLIR Op 정의: TDMPrefetchOp
Gluon 레벨에서 생성된 prefetch op은 TritonAMDGPU dialect의 TDMPrefetchOp으로 lowering된다.
// TritonAMDGPUOps.td에 추가된 op 정의
def TTAG_TDMPrefetchOp : TTAG_Op<"tdm_prefetch"> {
let arguments = (ins
TT_TensorDescType:$desc,
Variadic<I32>:$coords,
I1:$pred,
BoolAttr:$speculative
);
}
E2E 테스트: 실제 커널에서 prefetch 사용
@gluon.jit
def prefetch_kernel(desc):
amd.prefetch(desc, [0, 0], speculative=True)
def test_tdm_prefetch():
inp = torch.randn(128, 128, device="cuda", dtype=torch.float16)
desc = gluon.amd.gfx1250.TensorDescriptor(
base=inp, shape=[128, 128],
block_shape=[64, 64], ...)
prefetch_kernel[(1,)](desc)
왜 이게 좋은가
- 사용자 수준 성능 제어: 메모리 접근 패턴을 아는 사용자가 적절한 시점에 prefetch를 삽입하여 L2 cache hit rate를 높일 수 있다.
- 안전한 speculative 모드: speculative prefetch는 잘못된 주소에 대해서도 안전하게 무시되므로, 조건 분기 없이 공격적으로 사용할 수 있다.
- 계층적 추상화: 하드웨어 intrinsic → MLIR op → Gluon API의 3단계 추상화를 통해, 저수준 세부사항을 감추면서도 제어력을 제공한다.
정리
이 PR은 AMD TDM L2 prefetch를 Gluon DSL에서 호출할 수 있는 Python API를 추가한다. Tensor descriptor 기반의 prefetch 연산을 speculative/non-speculative 모드로 지원하며, MLIR op 정의부터 E2E 테스트까지 포함하는 완결된 구현이다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 핵심 코드와 explaination은 실제 PR diff를 기반으로 합니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [vllm] gRPC Server Entrypoint - 고성능 gRPC 서빙 지원
- 현재글 : [Triton] AMD Gluon DSL에 TDM L2 Prefetch 노출 — 사용자 수준 프리페치 제어
- 다음글 [PyTorch] MPS mul 성능 회귀 수정
댓글