본문으로 건너뛰기

[triton] AMD 비동기 복사에서 block 차원 중복 복사 허용

PR 링크: triton-lang/triton#8788 상태: Merged | 변경: +7 / -5

들어가며

Multi-CTA 환경에서 ttg.async_copy_global_to_local 연산은 각 CTA의 shared memory에 데이터를 로드합니다. 기존에는 block 차원(CTA 간 분배 차원)에서 redundant thread predicate가 활성화되어, 일부 CTA가 데이터를 로드하지 않는 문제가 있었습니다. 이 PR은 block 차원의 free variable mask를 0으로 설정하여 모든 CTA가 데이터를 로드하도록 수정합니다.

핵심 코드 분석

Before:

Value threadPred = emitRedundantThreadPredicate(getFreeVariableMasks(srcTy),
                                                rewriter, loc, targetInfo);

After:

auto freeVarMasks = getFreeVariableMasks(srcTy);
// We load redundant data on different CTAs so each CTA has a copy in its
// shared memory; the multicast mask will be used by the hardware to
// efficiently broadcast to different CTAs.
freeVarMasks[rewriter.getStringAttr("block")] = 0;
Value threadPred =
    emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);

freeVarMasks의 "block" 차원을 0으로 설정함으로써, block 차원에서는 redundant predicate가 생성되지 않습니다. 이렇게 하면 모든 CTA에서 동일한 데이터를 로드하고, 하드웨어의 multicast mask가 효율적으로 브로드캐스트를 처리합니다.

왜 이게 좋은가

  • 정확성 보장: 각 CTA가 자신의 shared memory에 데이터를 올바르게 복사합니다.
  • 최소한의 변경: 핵심 로직은 한 줄(freeVarMasks["block"] = 0)로 해결됩니다.
  • 하드웨어 활용: multicast mask를 통한 효율적인 브로드캐스트는 하드웨어에 위임합니다.

정리

Multi-CTA async copy에서 block 차원의 redundant thread predicate를 비활성화하여 모든 CTA가 데이터를 올바르게 로드하도록 수정한 간결한 버그 수정입니다.

참고 자료


이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글