본문으로 건너뛰기

[Triton] AMD gfx1250 tt.LoadOp에 multicast 지원 추가

PR 링크: triton-lang/triton#8759 상태: Merged | 변경: +234 / -23

들어가며

AMD GFX1250의 cluster 아키텍처에서 여러 CTA가 동일한 데이터를 필요로 할 때, 각 CTA가 개별적으로 글로벌 메모리에서 읽는 것은 비효율적이다. 이 PR은 tt.LoadOp에 multicast 지원을 추가하여, async_copy_global_to_local처럼 한 번의 메모리 요청으로 여러 CTA의 레지스터에 데이터를 broadcast한다.

핵심 코드 분석

CGALayout에서 multicast 감지

CGA encoding의 broadcast dimension을 감지하여 multicast mask를 설정한다.

// CGALayout에서 broadcasting base가 있으면
// 해당 차원의 데이터가 여러 CTA에 공유됨을 의미
// -> cluster_load로 한 번에 여러 CTA에 전달

cluster_load lowering

// Before: 각 CTA가 개별 load
%val = tt.load %ptr : tensor<128x16x!tt.ptr<f16>, #blocked>

// After: multicast가 감지되면 cluster_load 사용
// 한 CTA가 로드한 데이터를 cluster 내 다른 CTA에 broadcast

왜 이게 좋은가

  1. 대역폭 절약: 같은 데이터를 읽는 여러 CTA가 하나의 메모리 요청을 공유한다.
  2. 자동 적용: CGALayout에 broadcast 차원이 있으면 자동으로 multicast가 활성화된다.
  3. async_copy와 일관성: 동일한 multicast 감지 로직을 tt.load에도 적용하여 코드 일관성을 유지한다.
  4. GFX1250 최적화: cluster 아키텍처를 활용한 하드웨어 수준의 최적화다.

정리

이 PR은 AMD GFX1250의 tt.LoadOp에 multicast 지원을 추가했다. CGALayout의 broadcasting base를 감지하여 cluster_load를 사용하고, cluster 내 여러 CTA에 한 번의 메모리 요청으로 데이터를 전달한다.

참고 자료


이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.

댓글

관련 포스트

PR Analysis 의 다른글