[Triton] AMD TDM 연산에 multi-CTA 및 multicast 지원 추가
PR 링크: triton-lang/triton#8790 상태: Merged | 변경: +254 / -33
들어가며
AMD GPU의 CGA(Cooperative Grid Array)에서 여러 CTA가 cluster를 형성한다. Tensor Descriptor Mode(TDM)의 load/store 연산에서 multicast를 지원하면, 한 번의 메모리 요청으로 여러 CTA의 shared memory에 데이터를 동시에 쓸 수 있어 메모리 대역폭을 절약한다. 이 PR은 tt.load와 ttg.async_copy_global_to_local처럼 TDM 연산에서도 CGALayout의 broadcasting base를 감지하여 자동으로 multicast를 활성화한다.
핵심 코드 분석
Multicast 마스크 계산
CGALayout에 broadcasting base가 포함되어 있으면, 해당 차원에서 데이터가 여러 CTA에 공유된다. 이를 감지하여 multicast mask를 설정한다.
// CGALayout에서 broadcasting 차원 감지
auto cgaLayout = getDefaultCGALayout(tensorType, numCTAs);
// broadcasting base가 있으면 multicast 활성화
// 예: CGA=[2,1]에서 dim 0이 broadcast -> CTA 0,1에 동시 전송
TDM load에 multicast 적용
// Before: single CTA에만 load
amdg.async_tdm_copy_global_to_local %tensorDesc[%offset0, %offset1]
into %memDesc, %pred
: !tt.tensordesc<tensor<64x128xf16>>
-> !ttg.memdesc<64x128xf16, #shared, #smem, mutable>
// After: multicast mask와 함께 여러 CTA에 동시 load
amdg.async_tdm_copy_global_to_local %tensorDesc[%offset0, %offset1]
into %memDesc, %pred multicast_mask=%mask
: !tt.tensordesc<tensor<64x128xf16>>
-> !ttg.memdesc<64x128xf16, #shared, #smem, mutable>
왜 이게 좋은가
- 메모리 대역폭 절약: 한 번의 글로벌 메모리 요청으로 여러 CTA에 데이터를 전달하여 대역폭을 절약한다.
- 자동 감지: CGALayout의 broadcasting base를 자동으로 감지하여, 사용자가 명시적으로 multicast를 설정할 필요가 없다.
- 기존 패턴과 일관:
tt.load와ttg.async_copy_global_to_local에서 이미 사용하는 multicast 패턴을 TDM에도 동일하게 적용한다. - TDM load/store 모두 지원: global-to-local, local-to-global, gather, scatter 연산에 대해 multicast를 지원한다.
정리
이 PR은 AMD GPU의 TDM 연산에 multi-CTA multicast 지원을 추가했다. CGALayout의 broadcasting base를 감지하여 자동으로 multicast mask를 설정하고, 한 번의 메모리 요청으로 cluster 내 여러 CTA에 데이터를 전달한다.
참고 자료
이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Loki] 테넌트 rate limit 기반 셔플 샤딩으로 쿼리 성능 향상
- 현재글 : [Triton] AMD TDM 연산에 multi-CTA 및 multicast 지원 추가
- 다음글 [triton] Triton JIT 컴파일러 최적화: `inspect.getclosurevars` 제거를 통한 10,000배 성능 향상
댓글