[triton] AMD GFX9 Async Copy에서 Shared Memory 순서 버그 수정
PR 링크: triton-lang/triton#9874 상태: Merged | 변경: +54 / -4
들어가며
GPU에서 async copy를 사용할 때 shared memory의 데이터 배치 순서(order)는 성능과 정합성 모두에 직접적인 영향을 미칩니다. AMD GFX9 아키텍처에서는 LDS(Local Data Share)로의 직접 scattering을 지원하지 않는 경우, 각 warp가 연속된 메모리 청크를 기록해야 합니다. 이 PR은 스레드가 contiguous 차원을 정확히 커버하는 특수한 레이아웃(sizePerThread=[1,1], threadsPerWarp=[1,64])에서 shared memory 순서가 register 순서 대신 실제 메모리 순서를 따르도록 수정한 사례입니다.
핵심 코드 분석
문제의 핵심은 getSharedEncIfAllUsersAreDotEnc 함수에서 shared memory의 order를 결정하는 로직에 있었습니다.
Before:
auto llEnc = triton::gpu::toLinearEncoding(cast<RankedTensorType>(srcTy));
auto regOrder = llEnc.getOrder();
auto threadOrder = llEnc.getThreadOrder();
SetVector<unsigned> orderSet;
auto regContig = llEnc.getContigPerThread()[regOrder[0]];
// ...
if (finalRegContig > 1)
orderSet.insert(regOrder[0]);
orderSet.insert(threadOrder.begin(), threadOrder.end());
order = orderSet.takeVector();
After:
auto llEnc = triton::gpu::toLinearEncoding(cast<RankedTensorType>(srcTy));
auto threadOrder = llEnc.getThreadOrder();
SetVector<unsigned> orderSet;
auto regContig = llEnc.getContigPerThread()[order[0]];
// ...
if (finalRegContig > 1)
orderSet.insert(order[0]);
orderSet.insert(threadOrder.begin(), threadOrder.end());
order = orderSet.takeVector();
핵심 변경은 regOrder[0]를 order[0]으로 교체한 것입니다. regOrder는 레지스터의 논리적 순서이고, order는 실제 메모리 접근 순서(contiguous dimension)입니다. sizePerThread=[1,1]이고 threadsPerWarp=[1,64]인 경우, 레지스터는 dim0 방향으로 연속되지만 실제 메모리의 contiguous 차원은 dim1(order=[1,0])입니다. 기존 코드는 이 차이를 무시하고 register order를 사용하여 잘못된 shared memory layout을 생성했습니다.
왜 이게 좋은가
이 수정은 "레지스터 순서"와 "메모리 순서"라는 두 개념의 혼동을 바로잡습니다. Vectorization을 위해 fastest dimension을 보존할 때, register에서 연속인 차원이 아니라 실제 메모리에서 연속이면서 contiguity > 1인 차원을 기준으로 해야 합니다. 이는 특히 warp 내 스레드 수가 특정 차원을 완전히 커버하는 edge case에서 중요하며, 데이터 정합성 문제를 방지합니다.
정리
Async copy의 shared memory layout 결정 시, register order 대신 실제 메모리 order를 사용하도록 수정하여 스레드 배치가 contiguous 차원을 정확히 커버하는 경우의 정합성 문제를 해결했습니다.
참고 자료
이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [sglang] Diffusion 모델용 Fused QKNorm+RoPE CUDA 커널 추가
- 현재글 : [triton] AMD GFX9 Async Copy에서 Shared Memory 순서 버그 수정
- 다음글 [CPython] AArch64 JIT stencil에서 frame pointer 예약 활성화
댓글