본문으로 건너뛰기

[triton] AMD/Gluon: gfx1250에서 async_copy 런타임 테스트 추가 및 UpdateAsyncWaitCnt 활성화

PR 링크: triton-lang/triton#8664 상태: Merged | 변경: +89 / -2

들어가며

gfx1250에서 async_copy를 올바르게 사용하기 위해서는 다양한 shared memory layout(Swizzled, Padded 등)과 데이터 타입 조합에 대한 검증이 필요합니다. 이번 PR은 이러한 런타임 테스트를 추가하고, gfx1250에서 UpdateAsyncWaitCnt를 활성화합니다.

핵심 코드 분석

테스트 커널: global -> shared -> register -> global 라운드트립

@gluon.jit
def async_load_and_write_back_kernel(a_ptr, out_ptr, M, N, ...):
    buffer = ttgl.allocate_shared_memory(...)
    ttgl.amd.gfx1250.async_copy.global_to_shared(buffer, a_ptrs)
    ttgl.amd.gfx1250.async_copy.commit_group()
    ttgl.amd.gfx1250.async_copy.wait_group(0)
    res = buffer.load(blocked_layout)
    ttgl.store(out_ptrs, res, mask)

다양한 layout 조합 테스트

ASYNC_COPY_TEST_PARAM_SHARED_LAYOUT = pytest.mark.parametrize("vec_size, shared_layout", [
    (16, ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0])),
    (4, ttgl.SwizzledSharedLayout(4, 2, 4, [1, 0])),
    (8, ttgl.SwizzledSharedLayout(8, 2, 4, [1, 0])),
    (4, ttgl.PaddedSharedLayout.with_identity_for(...)),
    (1, ttgl.SwizzledSharedLayout(1, 1, 1, [0, 1])),  # 전치 순서
])
ASYNC_COPY_TEST_PARAM_DTYPE = pytest.mark.parametrize("dtype", [
    torch.float64, torch.float32, torch.float16, torch.float8_e4m3fn
])

false positive 제거

// TargetInfo.cpp
bool addAliasGroup = localLoadOp && requiresAliasInfoForAsyncOps() &&
                     isSyncedViaAsyncWait(localLoadOp);
// gfx1250에서 불필요한 alias 정보 추가를 방지

왜 이게 좋은가

  1. 조합 테스트: layout x dtype x 크기의 다양한 조합을 체계적으로 검증합니다.
  2. 에러 케이스 처리: vec_size * dtype.itemsize < 4인 경우 RuntimeError를 올바르게 발생시키는지도 확인합니다.
  3. 재사용 가능한 구조: 후속 PR에서 확장할 수 있도록 테스트 함수와 파라미터가 분리되어 있습니다.

정리

새 아키텍처 기능의 정확성은 다양한 조합 테스트를 통해 검증해야 합니다. 이 PR은 parametrize를 활용한 체계적인 테스트 설계의 좋은 예시입니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, PR의 실제 diff를 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글