[Triton] ttg.warp_id op 추가와 AMD 아키텍처별 변환 구현
들어가며
GPU 커널에서 현재 warp의 ID를 알아내는 것은 warp-level 최적화의 기본이다. 기존 Triton에서는 NVIDIA 전용 dialect에 nvg.warp_id op이 있었고, warp ID를 thread ID에서 나눗셈으로 계산했다. 이 PR은 warp_id op을 공용 ttg(TritonGPU) dialect로 이동시키고, AMD 아키텍처에서는 전용 하드웨어 명령어를 사용하도록 구현한다.
핵심 코드 분석
Before
// warp ID를 thread ID 나눗셈으로 계산
Value warpId = b.udiv(tid, warpSizeVal);
모든 아키텍처에서 thread ID를 warp 크기로 나누는 방식으로 warp ID를 얻었다. AMD에서는 이 나눗셈이 vector register에서 수행되어 비효율적이었다.
After
// 공용 warp_id op 사용
Value warpId = mlir::triton::gpu::WarpIdOp::create(rewriter, loc);
ttg.warp_id가 공용 op으로 정의되었고, AMD gfx1250에서는 llvm.amdgcn.wave.id intrinsic으로 직접 변환된다. NVIDIA에서는 기존처럼 thread ID 나눗셈을 사용하되, shuffle로 uniform 값으로 만든다.
// ttg dialect에 새 op 추가
def TTG_WarpIdOp : TTG_Op<"warp_id", [Pure]> {
let results = (outs I32:$result);
let assemblyFormat = "attr-dict";
}
AMD의 direct-to-LDS load에서도 warp ID hoisting 코드가 대폭 단순화되었다:
// Before: 38줄의 수동 hoisting + readfirstlane 삽입
auto insertPt = rewriter.saveInsertionPoint();
Operation *parentOp = insertPt.getBlock()->getParentOp();
while (!isa<LLVM::LLVMFuncOp>(parentOp)) { ... }
// ... 복잡한 hoisting 로직
// After: WarpIdOp이 알아서 처리
std::tie(laneId, warpId) = getLaneAndWarpId(rewriter, loc);
왜 이게 좋은가
- 하드웨어 최적화: AMD gfx1250에서
llvm.amdgcn.wave.id를 사용하면 LLVM의 uniformity analysis가 더 잘 동작한다. - 코드 중복 제거: NVIDIA와 AMD 공통 코드가
ttgdialect로 통합되었다. - 유지보수성: 38줄의 manual hoisting 코드가 단순한 op 하나로 대체되었다.
정리
dialect 설계에서 플랫폼 공통 op을 적절히 추상화하면, 각 백엔드에서 최적의 lowering을 선택할 수 있는 유연성을 확보할 수 있다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글