본문으로 건너뛰기

[Triton] AMD Gluon에서 async_wait을 commit group 기반으로 변경

PR 링크: triton-lang/triton#8605 상태: Merged | 변경: +836 / -96

들어가며

AMD CDNA4에서 Gluon 커널을 작성할 때, async_wait에 전달해야 하는 outstanding 하드웨어 명령어 수를 수동으로 계산하는 것은 매우 어렵다. 레이아웃, 크기, contiguity에 따라 명령어 수가 달라지기 때문이다. 이 PR은 async_wait의 의미론을 NVIDIA Gluon과 동일하게 commit group 수 기반으로 변경하고, UpdateAsyncWaitCount pass가 자동으로 하드웨어 명령어 수를 계산하도록 한다.

핵심 코드 분석

Before: 사용자가 하드웨어 명령어 수를 직접 지정

# Gluon 커널에서 사용자가 직접 intrinsic 수를 계산해야 함
gfx1250_async_copy.wait(num_instructions=4)  # 레이아웃에 따라 달라짐!

After: commit group 기반으로 변경

# After: NVIDIA와 동일한 의미론
gfx1250_async_copy.global_to_shared(smem_a, ptr_a)
gfx1250_async_copy.global_to_shared(smem_b, ptr_b)
gfx1250_async_copy.commit_group()  # commit group 생성

gfx1250_async_copy.global_to_shared(smem_a2, ptr_a2)
gfx1250_async_copy.commit_group()

# outstanding commit groups 수만 지정
gfx1250_async_copy.wait_group(num=1)  # 마지막 1개 commit 대기

UpdateAsyncWaitCount: IR 역방향 탐색으로 명령어 수 자동 계산

// IR을 역방향으로 탐색하며 모든 가능한 제어 흐름 경로에서
// N개의 outstanding commit group에 해당하는 최소 명령어 수를 계산
// - scf.for: 루프 바디 내의 명령어 수 * 반복 횟수
// - scf.if: then/else 중 최소 명령어 수

왜 이게 좋은가

  1. 사용자 편의성: 레이아웃 의존적인 하드웨어 명령어 수를 직접 계산할 필요가 없어진다.
  2. NVIDIA와 통일된 의미론: Gluon에서 AMD/NVIDIA 커널이 동일한 commit_group/wait_group 패턴을 사용할 수 있다.
  3. 자동 계산: UpdateAsyncWaitCount pass가 모든 제어 흐름 경로를 분석하여 정확한 명령어 수를 계산한다.
  4. token 불필요: Gluon 커널에서는 token이 없어도 IR 역방향 탐색으로 commit group을 추적한다.

정리

이 PR은 AMD Gluon에서 async_wait의 의미론을 하드웨어 명령어 수 기반에서 commit group 수 기반으로 변경했다. UpdateAsyncWaitCount pass를 확장하여 token 없이도 IR 역방향 탐색으로 정확한 하드웨어 명령어 수를 자동 계산한다. 이를 통해 NVIDIA/AMD Gluon 커널이 동일한 패턴으로 작성 가능해진다.

참고 자료


이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.

댓글

관련 포스트

PR Analysis 의 다른글