[Triton] AMD에 MemoryCounterWaitOp과 ROCDL lowering 추가
PR 링크: triton-lang/triton#8642 상태: Merged | 변경: +174 / -17
들어가며
AMD GPU에서 메모리 연산 완료를 대기하려면 s_waitcnt 명령어를 사용한다. 문제는 각 ISA 버전(GCN, GFX9, GFX10, GFX11, GFX12)마다 카운터의 비트 인코딩이 다르다는 것이다. 이 PR은 MemoryCounterWaitOp이라는 아키텍처 독립적인 연산을 도입하여, 각 타겟에 맞는 lowering을 자동 수행한다.
핵심 코드 분석
Op 정의 (TritonAMDGPUOps.td)
def MemoryCounterWaitOp : TT_AMDGPU_Op<"memory_counter_wait"> {
let arguments = (ins
OptionalAttr<I32Attr>:$load,
OptionalAttr<I32Attr>:$store,
OptionalAttr<I32Attr>:$ds
);
let assemblyFormat = [{
oilist( `load` `(` $load `)` | `store` `(` $store `)` | `ds` `(` $ds `)` )
}];
}
load, store, ds 세 가지 카운터를 optional로 지정할 수 있다.
ISA별 waitcnt 인코딩
static FailureOr<unsigned> encodeWaitcnt(
llvm::AMDGPU::IsaVersion isaVersion,
unsigned vmcnt, unsigned lgkmcnt) {
if (isaVersion.Major == 9) {
// Vmcnt = Waitcnt[15:14,3:0], Lgkmcnt = Waitcnt[11:8]
unsigned lowBits = vmcnt & 0xF;
unsigned highBits = (vmcnt >> 4) << 14;
return lowBits | highBits | (expcnt << 4) | (lgkmcnt << 8);
}
if (isaVersion.Major == 11) {
// Vmcnt = Waitcnt[15:10], Lgkmcnt = Waitcnt[9:4]
return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
}
// ...
}
GFX12 이상은 전용 명령어 사용
// fgx12+: lower to ROCDL::WaitDscntOp
if (isaVersion.Major >= 12) {
// 별도의 ds_waitcnt 명령어로 lowering
}
왜 이게 좋은가
- 추상화:
amdg.memory_counter_wait ds(0)로 ISA 무관하게 표현 - 정확한 인코딩: 각 아키텍처의 비트 레이아웃을 정확하게 반영
- 유지보수성: 새 아키텍처 추가 시 인코딩 함수만 확장하면 됨
정리
GPU 동기화 명령어의 비트 인코딩은 아키텍처마다 달라서 하드코딩하면 유지보수가 어렵다. 고수준 연산으로 추상화하고 lowering 단계에서 타겟별 인코딩을 적용하는 것이 MLIR 기반 컴파일러의 장점이다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Triton] AMD LLVM 백엔드에 커스텀 스케줄러 옵션 추가
- 현재글 : [Triton] AMD에 MemoryCounterWaitOp과 ROCDL lowering 추가
- 다음글 [Triton] Concurrency Sanitizer에 TMA Store 검증 추가
댓글