[Triton] AMD TDM AsyncWait을 UpdateAsyncWaitCount에서 지원
PR 링크: triton-lang/triton#9352 상태: Merged | 변경: +320 / -48
들어가며
AMD GPU에서 TDM(Tensor Descriptor Mode)의 scatter/gather 연산은 하나의 TTGIR 연산에서 여러 개의 하드웨어 intrinsic/assembly 명령어를 생성할 수 있다. async_wait의 waitcnt는 하드웨어 명령어 수 기준이므로, TTGIR 연산 수 대신 실제 intrinsic 수를 기반으로 계산해야 한다. 이 PR은 기존의 async_copy_global_to_local 분석 pass를 확장하여 TDM 연산도 처리한다.
핵심 코드 분석
Before: TDM wait은 별도 처리 없음
기존 async_tdm_wait은 num attribute에 사용자가 직접 지정한 값을 그대로 사용했다.
// Before: 사용자가 직접 num을 지정
amdg.async_tdm_wait {num = 0 : i32}
After: async_tdm_intrinsic_wait으로 변환 + 자동 count 계산
// After: UpdateAsyncWaitCount가 실제 intrinsic 수를 계산
// 64xi32 indices: 64/8 = 8 instructions per gather/scatter
amdg.async_tdm_gather %tensorDesc[%row_indices_i32, %c0_i32] to %memDesc
amdg.async_tdm_scatter %tensorDesc[%row_indices_i32, %c0_i32] from %memDesc
// 256xi16 indices: 256/16 = 16 instructions per gather/scatter
amdg.async_tdm_gather %tensorDesc[%row_indices_i16, %c0_i32] to %memDesc
amdg.async_tdm_scatter %tensorDesc[%row_indices_i16, %c0_i32] from %memDesc
// CHECK: amdg.async_tdm_intrinsic_wait {count = 0} // 모든 48개 대기
// CHECK: amdg.async_tdm_intrinsic_wait {count = 16} // 마지막 scatter 제외
// CHECK: amdg.async_tdm_intrinsic_wait {count = 32} // 마지막 2개 제외
// CHECK: amdg.async_tdm_intrinsic_wait {count = 40} // 첫 scatter만 대기
// CHECK: amdg.async_tdm_intrinsic_wait {count = 48} // 첫 gather만 대기
혼합 사용 시 독립적 카운팅
TDM과 일반 async_copy는 서로 다른 하드웨어 카운터를 사용하므로 독립적으로 계산된다.
// TDM과 async_copy 혼합 시
%1 = amdg.async_tdm_copy_global_to_local ... // TDM counter
%2 = ttg.async_copy_global_to_local ... // async copy counter
%3 = amdg.async_tdm_copy_global_to_local ... // TDM counter
// CHECK: amdg.async_tdm_intrinsic_wait {count = 1} // TDM만 카운트
// CHECK: amdg.async_wait {num_inst = 2} // async copy만 카운트
// CHECK: amdg.async_tdm_intrinsic_wait {count = 0} // TDM만 카운트
왜 이게 좋은가
- 정확한 동기화: TDM scatter/gather가 생성하는 실제 intrinsic 수를 기반으로 waitcnt를 계산하여 데이터 정합성을 보장한다.
- Gluon 호환: token 기반 approach와 TTGIR 연산 수 기반 approach 모두 지원한다.
- 카운터 분리: TDM과 일반 async copy의 하드웨어 카운터를 독립적으로 관리한다.
- index 타입별 계산: i32(64/8=8개), i16(256/16=16개) 등 index 타입에 따른 intrinsic 수를 정확히 반영한다.
정리
이 PR은 AMD GPU의 TDM scatter/gather 연산에 대한 정확한 async_wait count 계산을 UpdateAsyncWaitCount pass에 추가했다. 하나의 TTGIR 연산이 여러 하드웨어 명령어로 변환되는 상황을 올바르게 처리하며, TDM과 async copy 카운터를 독립적으로 관리한다.
참고 자료
이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Loki] 인덱스 빌더에서 오브젝트 다운로드 시 슬라이스 사전 할당으로 메모리 효율화
- 현재글 : [Triton] AMD TDM AsyncWait을 UpdateAsyncWaitCount에서 지원
- 다음글 [Ray Data] 중복 batch_format 유효성 검사 제거
댓글