본문으로 건너뛰기

[Triton] gfx1250에 async_copy_local_to_global 추가

PR 링크: triton-lang/triton#8984 상태: Merged | 변경: +678 / -79

들어가며

AMD GFX1250은 global-to-LDS(shared) 비동기 복사를 지원하지만, 그 반대 방향인 shared-to-global 비동기 복사는 별도로 구현이 필요했다. 이 PR은 Gluon 프론트엔드에서 async_copy_local_to_global 연산을 정의하고, LLVM IR로의 lowering과 UpdateAsyncWaitCount에서의 카운팅을 추가한다.

핵심 코드 분석

Gluon 프론트엔드 API

@builtin
def shared_to_global(pointer, smem, mask=None, cache_modifier="", _semantic=None):
    """
    Asynchronously copy elements from shared memory to global memory.
    Requires manual synchronization via `wait_group` before accessing
    the stored data.
    """
    _check(pointer.type.is_block(), lambda: "expected ptr to be a tensor")
    _check(isinstance(pointer.type.layout, DistributedLayout),
           lambda: "expected ptr type layout to be BlockedLayout or SliceLayout")
    _check(smem.shape == pointer.shape,
           lambda: f"expected smem shape to match pointer shape")
    # ...
    _semantic.builder.create_async_copy_local_to_global(
        smem.handle, pointer.handle, mask_handle,
        cache_modifier, ir.EVICTION_POLICY.NORMAL)

LLVM lowering (128-bit store 예시)

// 8 elements x 32bit = 256 bits -> 2x 128-bit async stores
// CHECK-COUNT-2: llvm.amdgcn.global.store.async.from.lds.b128
// CHECK-NOT: llvm.amdgcn.global.store.async.from.lds
%2 = amdg.async_copy_local_to_global %arg1, %arg0
    : !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
    -> tensor<32x32x!tt.ptr<f32>, #blocked>

UpdateAsyncWaitCount에서 store 연산 카운팅

# Gluon 사용법
gfx1250_async_copy.shared_to_global(ptr + offsets, smem)
gfx1250_async_copy.shared_to_global(ptr + offsets, smem, mask)
gfx1250_async_copy.commit_group()

왜 이게 좋은가

  1. 양방향 비동기 복사: global-to-shared 뿐 아니라 shared-to-global도 비동기로 처리하여 compute-transfer 오버랩이 가능해진다.
  2. Gluon 통합: Gluon 프론트엔드에서 shared_to_global API로 직관적으로 사용할 수 있다.
  3. 벡터화 지원: contiguity 정보를 활용하여 32/64/128-bit store 중 최적을 선택한다.
  4. Wait count 통합: UpdateAsyncWaitCount에서 store 연산도 카운팅하여 올바른 동기화를 보장한다.

정리

이 PR은 AMD GFX1250의 Gluon 프론트엔드에 async_copy_local_to_global 연산을 추가했다. Op 정의, LLVM lowering, wait count 계산, 프론트엔드 API, lit 테스트를 포함하는 완전한 구현이다.

참고 자료


이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.

댓글

관련 포스트

PR Analysis 의 다른글