[triton] AMD 비동기 복사에서 block 차원 중복 복사 허용
PR 링크: triton-lang/triton#8788 상태: Merged | 변경: +7 / -5
들어가며
Multi-CTA 환경에서 ttg.async_copy_global_to_local 연산은 각 CTA의 shared memory에 데이터를 로드합니다. 기존에는 block 차원(CTA 간 분배 차원)에서 redundant thread predicate가 활성화되어, 일부 CTA가 데이터를 로드하지 않는 문제가 있었습니다. 이 PR은 block 차원의 free variable mask를 0으로 설정하여 모든 CTA가 데이터를 로드하도록 수정합니다.
핵심 코드 분석
Before:
Value threadPred = emitRedundantThreadPredicate(getFreeVariableMasks(srcTy),
rewriter, loc, targetInfo);
After:
auto freeVarMasks = getFreeVariableMasks(srcTy);
// We load redundant data on different CTAs so each CTA has a copy in its
// shared memory; the multicast mask will be used by the hardware to
// efficiently broadcast to different CTAs.
freeVarMasks[rewriter.getStringAttr("block")] = 0;
Value threadPred =
emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
freeVarMasks의 "block" 차원을 0으로 설정함으로써, block 차원에서는 redundant predicate가 생성되지 않습니다. 이렇게 하면 모든 CTA에서 동일한 데이터를 로드하고, 하드웨어의 multicast mask가 효율적으로 브로드캐스트를 처리합니다.
왜 이게 좋은가
- 정확성 보장: 각 CTA가 자신의 shared memory에 데이터를 올바르게 복사합니다.
- 최소한의 변경: 핵심 로직은 한 줄(
freeVarMasks["block"] = 0)로 해결됩니다. - 하드웨어 활용: multicast mask를 통한 효율적인 브로드캐스트는 하드웨어에 위임합니다.
정리
Multi-CTA async copy에서 block 차원의 redundant thread predicate를 비활성화하여 모든 CTA가 데이터를 올바르게 로드하도록 수정한 간결한 버그 수정입니다.
참고 자료
이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Ray] iter_batches에서 프리페치 버퍼링을 올바르게 처리하여 지연시간 안정화
- 현재글 : [triton] AMD 비동기 복사에서 block 차원 중복 복사 허용
- 다음글 [Triton] clamp 최적화를 scalar에도 적용 — fmin.xorsign.abs 활용
댓글