[Triton] AMD Gluon에서 async_wait을 commit group 기반으로 변경
PR 링크: triton-lang/triton#8605 상태: Merged | 변경: +836 / -96
들어가며
AMD CDNA4에서 Gluon 커널을 작성할 때, async_wait에 전달해야 하는 outstanding 하드웨어 명령어 수를 수동으로 계산하는 것은 매우 어렵다. 레이아웃, 크기, contiguity에 따라 명령어 수가 달라지기 때문이다. 이 PR은 async_wait의 의미론을 NVIDIA Gluon과 동일하게 commit group 수 기반으로 변경하고, UpdateAsyncWaitCount pass가 자동으로 하드웨어 명령어 수를 계산하도록 한다.
핵심 코드 분석
Before: 사용자가 하드웨어 명령어 수를 직접 지정
# Gluon 커널에서 사용자가 직접 intrinsic 수를 계산해야 함
gfx1250_async_copy.wait(num_instructions=4) # 레이아웃에 따라 달라짐!
After: commit group 기반으로 변경
# After: NVIDIA와 동일한 의미론
gfx1250_async_copy.global_to_shared(smem_a, ptr_a)
gfx1250_async_copy.global_to_shared(smem_b, ptr_b)
gfx1250_async_copy.commit_group() # commit group 생성
gfx1250_async_copy.global_to_shared(smem_a2, ptr_a2)
gfx1250_async_copy.commit_group()
# outstanding commit groups 수만 지정
gfx1250_async_copy.wait_group(num=1) # 마지막 1개 commit 대기
UpdateAsyncWaitCount: IR 역방향 탐색으로 명령어 수 자동 계산
// IR을 역방향으로 탐색하며 모든 가능한 제어 흐름 경로에서
// N개의 outstanding commit group에 해당하는 최소 명령어 수를 계산
// - scf.for: 루프 바디 내의 명령어 수 * 반복 횟수
// - scf.if: then/else 중 최소 명령어 수
왜 이게 좋은가
- 사용자 편의성: 레이아웃 의존적인 하드웨어 명령어 수를 직접 계산할 필요가 없어진다.
- NVIDIA와 통일된 의미론: Gluon에서 AMD/NVIDIA 커널이 동일한
commit_group/wait_group패턴을 사용할 수 있다. - 자동 계산:
UpdateAsyncWaitCountpass가 모든 제어 흐름 경로를 분석하여 정확한 명령어 수를 계산한다. - token 불필요: Gluon 커널에서는 token이 없어도 IR 역방향 탐색으로 commit group을 추적한다.
정리
이 PR은 AMD Gluon에서 async_wait의 의미론을 하드웨어 명령어 수 기반에서 commit group 수 기반으로 변경했다. UpdateAsyncWaitCount pass를 확장하여 token 없이도 IR 역방향 탐색으로 정확한 하드웨어 명령어 수를 자동 계산한다. 이를 통해 NVIDIA/AMD Gluon 커널이 동일한 패턴으로 작성 가능해진다.
참고 자료
이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Ray] 단일 노드 LLM 배치 추론 성능 기준선 벤치마크 및 회귀 가드 추가
- 현재글 : [Triton] AMD Gluon에서 async_wait을 commit group 기반으로 변경
- 다음글 [triton] AMD: BufferLoadToLocal을 UpdateAsyncWaitCount에 포함하여 성능 회귀 수정
댓글