[Triton] gfx1250에서 async_copy multicast 지원
PR 링크: triton-lang/triton#8719 상태: Merged | 변경: +255 / -20
들어가며
GPU에서 여러 CTA(Cooperative Thread Array)가 같은 데이터를 사용할 때, 각 CTA가 개별적으로 global memory에서 로드하면 메모리 대역폭이 낭비된다. Multicast는 하나의 로드로 여러 CTA의 shared memory에 동시에 데이터를 쓰는 기술이다. 이 PR은 AMD gfx1250의 llvm.amdgcn.cluster.load.async.to.lds 명령어를 활용하여 이 기능을 구현한다.
핵심 코드 분석
Multicast mask 계산
CTASplitNum과 CTAsPerCGA를 분석하여 어떤 CTA들이 같은 데이터를 공유하는지 결정하고, 그에 맞는 비트 마스크를 생성한다.
핵심 테스트 케이스:
// 8 CTAs, 2 multicast groups of 4 CTAs each
// base mask = 0b1010101 (85), non free mask = -7 (~0b110)
#blocked = #ttg.blocked<{CTAsPerCGA = [8, 1], CTASplitNum = [2, 1]}>
// CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-7 : i32) : i32
// CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
// CHECK: %[[GROUP_MASK:.*]] = llvm.mlir.constant(85 : i32) : i32
// CHECK: %[[CTA_MASK:.*]] = llvm.shl %[[GROUP_MASK]], %[[SHIFT_AMOUNT]]
free dimension(공유 차원)과 non-free dimension을 구분하여 base mask를 만들고, 현재 CTA ID에 따라 shift하여 최종 multicast mask를 생성한다.
데이터 공유가 없는 경우
// 16 CTAs split into 16 multicast groups — no sharing
// CHECK-NOT: llvm.amdgcn.cluster.load.async.to.lds
// CHECK: llvm.amdgcn.global.load.async.to.lds.b64
CTASplitNum이 CTAsPerCGA와 같으면 각 CTA가 독립적인 데이터를 사용하므로, 일반 async load를 사용한다.
왜 이게 좋은가
- 메모리 대역폭 절약: 하나의 로드로 여러 CTA의 LDS에 동시 기록
- 자동 최적화: 레이아웃 분석을 통해 multicast가 가능한 경우를 자동 감지
- Fallback: 공유 데이터가 없으면 자동으로 일반 로드 사용
정리
Multicast는 multi-CTA 커널의 핵심 최적화다. 특히 대규모 matmul이나 attention에서 같은 KV 블록을 여러 CTA가 공유할 때 큰 성능 향상을 가져온다. 비트 마스크 기반의 CTA 그룹핑 로직은 다양한 CTA 토폴로지를 효율적으로 지원한다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [vllm] MP Executor로 멀티 노드 분산 추론 지원
- 현재글 : [Triton] gfx1250에서 async_copy multicast 지원
- 다음글 [Loki] fsGroupChangePolicy=OnRootMismatch로 Pod 시작 속도 향상
댓글