본문으로 건너뛰기

[Triton] ttg.warp_id op 추가와 AMD 아키텍처별 변환 구현

들어가며

GPU 커널에서 현재 warp의 ID를 알아내는 것은 warp-level 최적화의 기본이다. 기존 Triton에서는 NVIDIA 전용 dialect에 nvg.warp_id op이 있었고, warp ID를 thread ID에서 나눗셈으로 계산했다. 이 PR은 warp_id op을 공용 ttg(TritonGPU) dialect로 이동시키고, AMD 아키텍처에서는 전용 하드웨어 명령어를 사용하도록 구현한다.

핵심 코드 분석

Before

// warp ID를 thread ID 나눗셈으로 계산
Value warpId = b.udiv(tid, warpSizeVal);

모든 아키텍처에서 thread ID를 warp 크기로 나누는 방식으로 warp ID를 얻었다. AMD에서는 이 나눗셈이 vector register에서 수행되어 비효율적이었다.

After

// 공용 warp_id op 사용
Value warpId = mlir::triton::gpu::WarpIdOp::create(rewriter, loc);

ttg.warp_id가 공용 op으로 정의되었고, AMD gfx1250에서는 llvm.amdgcn.wave.id intrinsic으로 직접 변환된다. NVIDIA에서는 기존처럼 thread ID 나눗셈을 사용하되, shuffle로 uniform 값으로 만든다.

// ttg dialect에 새 op 추가
def TTG_WarpIdOp : TTG_Op<"warp_id", [Pure]> {
  let results = (outs I32:$result);
  let assemblyFormat = "attr-dict";
}

AMD의 direct-to-LDS load에서도 warp ID hoisting 코드가 대폭 단순화되었다:

// Before: 38줄의 수동 hoisting + readfirstlane 삽입
auto insertPt = rewriter.saveInsertionPoint();
Operation *parentOp = insertPt.getBlock()->getParentOp();
while (!isa<LLVM::LLVMFuncOp>(parentOp)) { ... }
// ... 복잡한 hoisting 로직

// After: WarpIdOp이 알아서 처리
std::tie(laneId, warpId) = getLaneAndWarpId(rewriter, loc);

왜 이게 좋은가

  • 하드웨어 최적화: AMD gfx1250에서 llvm.amdgcn.wave.id를 사용하면 LLVM의 uniformity analysis가 더 잘 동작한다.
  • 코드 중복 제거: NVIDIA와 AMD 공통 코드가 ttg dialect로 통합되었다.
  • 유지보수성: 38줄의 manual hoisting 코드가 단순한 op 하나로 대체되었다.

정리

dialect 설계에서 플랫폼 공통 op을 적절히 추상화하면, 각 백엔드에서 최적의 lowering을 선택할 수 있는 유연성을 확보할 수 있다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글