[Triton] gfx1250에 async_copy_local_to_global 추가
PR 링크: triton-lang/triton#8984 상태: Merged | 변경: +678 / -79
들어가며
AMD GFX1250은 global-to-LDS(shared) 비동기 복사를 지원하지만, 그 반대 방향인 shared-to-global 비동기 복사는 별도로 구현이 필요했다. 이 PR은 Gluon 프론트엔드에서 async_copy_local_to_global 연산을 정의하고, LLVM IR로의 lowering과 UpdateAsyncWaitCount에서의 카운팅을 추가한다.
핵심 코드 분석
Gluon 프론트엔드 API
@builtin
def shared_to_global(pointer, smem, mask=None, cache_modifier="", _semantic=None):
"""
Asynchronously copy elements from shared memory to global memory.
Requires manual synchronization via `wait_group` before accessing
the stored data.
"""
_check(pointer.type.is_block(), lambda: "expected ptr to be a tensor")
_check(isinstance(pointer.type.layout, DistributedLayout),
lambda: "expected ptr type layout to be BlockedLayout or SliceLayout")
_check(smem.shape == pointer.shape,
lambda: f"expected smem shape to match pointer shape")
# ...
_semantic.builder.create_async_copy_local_to_global(
smem.handle, pointer.handle, mask_handle,
cache_modifier, ir.EVICTION_POLICY.NORMAL)
LLVM lowering (128-bit store 예시)
// 8 elements x 32bit = 256 bits -> 2x 128-bit async stores
// CHECK-COUNT-2: llvm.amdgcn.global.store.async.from.lds.b128
// CHECK-NOT: llvm.amdgcn.global.store.async.from.lds
%2 = amdg.async_copy_local_to_global %arg1, %arg0
: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
-> tensor<32x32x!tt.ptr<f32>, #blocked>
UpdateAsyncWaitCount에서 store 연산 카운팅
# Gluon 사용법
gfx1250_async_copy.shared_to_global(ptr + offsets, smem)
gfx1250_async_copy.shared_to_global(ptr + offsets, smem, mask)
gfx1250_async_copy.commit_group()
왜 이게 좋은가
- 양방향 비동기 복사: global-to-shared 뿐 아니라 shared-to-global도 비동기로 처리하여 compute-transfer 오버랩이 가능해진다.
- Gluon 통합: Gluon 프론트엔드에서
shared_to_globalAPI로 직관적으로 사용할 수 있다. - 벡터화 지원: contiguity 정보를 활용하여 32/64/128-bit store 중 최적을 선택한다.
- Wait count 통합:
UpdateAsyncWaitCount에서 store 연산도 카운팅하여 올바른 동기화를 보장한다.
정리
이 PR은 AMD GFX1250의 Gluon 프론트엔드에 async_copy_local_to_global 연산을 추가했다. Op 정의, LLVM lowering, wait count 계산, 프론트엔드 API, lit 테스트를 포함하는 완전한 구현이다.
참고 자료
이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] Async 연산에 명시적 의미론(Semantics) 문서 추가
- 현재글 : [Triton] gfx1250에 async_copy_local_to_global 추가
- 다음글 [triton] wgmma wait(0)를 accumulator 첫 사용 시점으로 지연하여 MMA-epilogue 오버랩 달성
댓글