본문으로 건너뛰기

[triton] AMD Async Wait Count에서 Warp Free Variable 및 Register Zero Base 버그 수정

PR 링크: triton-lang/triton#9732 상태: Merged | 변경: +64 / -3

들어가며

GPU에서 async copy 연산의 wait count를 정확히 계산하는 것은 동기화 정합성의 핵심입니다. 잘못된 wait count는 데이터 레이스나 불필요한 대기를 유발할 수 있습니다. 이 PR은 AMD 백엔드에서 async copy 명령어 수 산출 시 발생하던 두 가지 독립적인 버그를 수정합니다.

핵심 코드 분석

1. Warp Free Variable 처리

warp 차원에 free variable가 있으면 비정규(non-canonical) warp는 async copy를 실행하지 않습니다. 기존 코드는 이를 고려하지 않았습니다.

Before:

LinearLayout globalToSharedLayout =
    globalLayout.invertAndCompose(sharedLayout);
contig = std::min(contig, globalToSharedLayout.getNumConsecutiveInOut());
// ... 바로 register 수 계산으로 진행

After:

LinearLayout globalToSharedLayout =
    globalLayout.invertAndCompose(sharedLayout);

auto kWarp = StringAttr::get(globalType.getContext(), "warp");
if (globalToSharedLayout.getFreeVariableMasks().lookup(kWarp) != 0) {
    return 0;  // 비정규 warp는 load를 건너뜀
}

2. Register Zero Base 제거

register 차원에서 zero base가 있으면 여러 register index가 동일한 offset에 매핑되어 실제 load 명령어가 생성되지 않습니다.

Before:

int numberOfRegisters = globalToSharedLayout.getInDimSize(
    StringAttr::get(globalType.getContext(), "register"));
return std::max(1, numberOfRegisters / contig);

After:

auto kReg = StringAttr::get(globalType.getContext(), "register");
int numberOfRegisters =
    globalToSharedLayout.removeZeroBasesAlongDim(kReg).getInDimSize(kReg);
return std::max(1, numberOfRegisters / contig);

왜 이게 좋은가

이 두 수정은 모두 wait count의 과대 계산을 방지합니다. 과대 계산된 wait count는 GPU가 실제로 존재하지 않는 명령어 완료를 불필요하게 기다리게 하여 성능 저하를 유발합니다. 반대로 과소 계산은 데이터 정합성 문제를 야기할 수 있어, 정확한 계산이 필수적입니다. LinearLayout의 free variable mask와 zero base 제거를 활용한 해법은 수학적으로 정확하고 일반적입니다.

정리

Warp free variable이 있는 레이아웃에서 명령어 수를 0으로 반환하고, register zero base를 제거하여 실제 생성되는 명령어 수만 카운트하도록 수정했습니다.

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글