본문으로 건너뛰기

[Triton] AMD에 MemoryCounterWaitOp과 ROCDL lowering 추가

PR 링크: triton-lang/triton#8642 상태: Merged | 변경: +174 / -17

들어가며

AMD GPU에서 메모리 연산 완료를 대기하려면 s_waitcnt 명령어를 사용한다. 문제는 각 ISA 버전(GCN, GFX9, GFX10, GFX11, GFX12)마다 카운터의 비트 인코딩이 다르다는 것이다. 이 PR은 MemoryCounterWaitOp이라는 아키텍처 독립적인 연산을 도입하여, 각 타겟에 맞는 lowering을 자동 수행한다.

핵심 코드 분석

Op 정의 (TritonAMDGPUOps.td)

def MemoryCounterWaitOp : TT_AMDGPU_Op<"memory_counter_wait"> {
  let arguments = (ins
    OptionalAttr<I32Attr>:$load,
    OptionalAttr<I32Attr>:$store,
    OptionalAttr<I32Attr>:$ds
  );
  let assemblyFormat = [{
    oilist( `load` `(` $load `)` | `store` `(` $store `)` | `ds` `(` $ds `)` )
  }];
}

load, store, ds 세 가지 카운터를 optional로 지정할 수 있다.

ISA별 waitcnt 인코딩

static FailureOr<unsigned> encodeWaitcnt(
    llvm::AMDGPU::IsaVersion isaVersion,
    unsigned vmcnt, unsigned lgkmcnt) {
  if (isaVersion.Major == 9) {
    // Vmcnt = Waitcnt[15:14,3:0], Lgkmcnt = Waitcnt[11:8]
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    return lowBits | highBits | (expcnt << 4) | (lgkmcnt << 8);
  }
  if (isaVersion.Major == 11) {
    // Vmcnt = Waitcnt[15:10], Lgkmcnt = Waitcnt[9:4]
    return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
  }
  // ...
}

GFX12 이상은 전용 명령어 사용

// fgx12+: lower to ROCDL::WaitDscntOp
if (isaVersion.Major >= 12) {
    // 별도의 ds_waitcnt 명령어로 lowering
}

왜 이게 좋은가

  1. 추상화: amdg.memory_counter_wait ds(0)로 ISA 무관하게 표현
  2. 정확한 인코딩: 각 아키텍처의 비트 레이아웃을 정확하게 반영
  3. 유지보수성: 새 아키텍처 추가 시 인코딩 함수만 확장하면 됨

정리

GPU 동기화 명령어의 비트 인코딩은 아키텍처마다 달라서 하드코딩하면 유지보수가 어렵다. 고수준 연산으로 추상화하고 lowering 단계에서 타겟별 인코딩을 적용하는 것이 MLIR 기반 컴파일러의 장점이다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글

관련 포스트

PR Analysis 의 다른글