본문으로 건너뛰기

[Triton] gfx1250에서 async_copy multicast 지원

PR 링크: triton-lang/triton#8719 상태: Merged | 변경: +255 / -20

들어가며

GPU에서 여러 CTA(Cooperative Thread Array)가 같은 데이터를 사용할 때, 각 CTA가 개별적으로 global memory에서 로드하면 메모리 대역폭이 낭비된다. Multicast는 하나의 로드로 여러 CTA의 shared memory에 동시에 데이터를 쓰는 기술이다. 이 PR은 AMD gfx1250의 llvm.amdgcn.cluster.load.async.to.lds 명령어를 활용하여 이 기능을 구현한다.

핵심 코드 분석

Multicast mask 계산

CTASplitNum과 CTAsPerCGA를 분석하여 어떤 CTA들이 같은 데이터를 공유하는지 결정하고, 그에 맞는 비트 마스크를 생성한다.

핵심 테스트 케이스:

// 8 CTAs, 2 multicast groups of 4 CTAs each
// base mask = 0b1010101 (85), non free mask = -7 (~0b110)
#blocked = #ttg.blocked<{CTAsPerCGA = [8, 1], CTASplitNum = [2, 1]}>

// CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-7 : i32) : i32
// CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]]
// CHECK: %[[GROUP_MASK:.*]] = llvm.mlir.constant(85 : i32) : i32
// CHECK: %[[CTA_MASK:.*]] = llvm.shl %[[GROUP_MASK]], %[[SHIFT_AMOUNT]]

free dimension(공유 차원)과 non-free dimension을 구분하여 base mask를 만들고, 현재 CTA ID에 따라 shift하여 최종 multicast mask를 생성한다.

데이터 공유가 없는 경우

// 16 CTAs split into 16 multicast groups — no sharing
// CHECK-NOT: llvm.amdgcn.cluster.load.async.to.lds
// CHECK: llvm.amdgcn.global.load.async.to.lds.b64

CTASplitNum이 CTAsPerCGA와 같으면 각 CTA가 독립적인 데이터를 사용하므로, 일반 async load를 사용한다.

왜 이게 좋은가

  1. 메모리 대역폭 절약: 하나의 로드로 여러 CTA의 LDS에 동시 기록
  2. 자동 최적화: 레이아웃 분석을 통해 multicast가 가능한 경우를 자동 감지
  3. Fallback: 공유 데이터가 없으면 자동으로 일반 로드 사용

정리

Multicast는 multi-CTA 커널의 핵심 최적화다. 특히 대규모 matmul이나 attention에서 같은 KV 블록을 여러 CTA가 공유할 때 큰 성능 향상을 가져온다. 비트 마스크 기반의 CTA 그룹핑 로직은 다양한 CTA 토폴로지를 효율적으로 지원한다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글

관련 포스트

PR Analysis 의 다른글