[triton] GFX1250에서 AsyncCopy의 OOB Shared Memory 주소를 이용한 마스킹
PR 링크: triton-lang/triton#9761 상태: Merged | 변경: +62 / -5
들어가며
Async copy 연산에서 일부 스레드의 로드를 조건부로 건너뛰어야 할 때, 전통적인 방법은 브랜치(분기)를 사용하는 것입니다. 그러나 GPU에서 브랜치는 warp divergence를 유발하여 성능에 부정적인 영향을 미칩니다. GFX1250은 per-lane LDS 주소를 지원하는데, 이 PR은 이 하드웨어 특성을 활용하여 마스킹된 스레드에 out-of-range 주소를 설정하는 방식으로 브랜치 없이 마스킹을 구현합니다.
핵심 코드 분석
Before (모든 아키텍처에서 브랜치 사용):
auto cond = b.and_(threadPred, maybeSwizzledMaskElem);
auto [loadBlock, afterLoadBlock] = emitBranch(rewriter, loc, cond);
emitAsyncLoad(rewriter, loc, targetInfo, vecBits, srcElem, shmemAddr,
op.getCache(), multicastMask);
rewriter.setInsertionPointToStart(afterLoadBlock);
After (GFX1250에서 OOB 주소로 마스킹):
if (targetInfo.supportsDirectToLDSScattering()) {
// HW will drop the load before fetching from global memory
Value outOfRangeAddress =
b.inttoptr(shmemAddr.getType(), b.i32_val(0x7FFFFFFF));
Value predicatedAddress = b.select(cond, shmemAddr, outOfRangeAddress);
emitAsyncLoad(rewriter, loc, targetInfo, vecBits, srcElem,
predicatedAddress, op.getCache(), multicastMask);
} else {
// 기존 브랜치 방식 유지
auto [loadBlock, afterLoadBlock] = emitBranch(rewriter, loc, cond);
emitAsyncLoad(...);
rewriter.setInsertionPointToStart(afterLoadBlock);
}
왜 이게 좋은가
0x7FFFFFFF라는 out-of-range LDS 주소를 사용하면 하드웨어가 global memory fetch 자체를 수행하지 않고 해당 lane의 로드를 drop합니다. 이는 브랜치 없이 마스킹을 달성하여 다음과 같은 이점을 제공합니다: (1) warp divergence 제거로 인한 실행 효율 향상, (2) 제어 흐름 그래프 단순화로 컴파일러 최적화 기회 증가, (3) 명령어 수 감소로 instruction cache 효율 개선. 기존 아키텍처에서는 여전히 브랜치 방식을 사용하여 호환성을 유지합니다.
정리
GFX1250의 per-lane LDS 주소 기능을 활용하여 async copy 마스킹을 브랜치 대신 out-of-range 주소 방식으로 구현하고, 이전 아키텍처에서는 기존 방식을 유지하도록 분기했습니다.
참고 자료
이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [ACE-Step-1.5] 외부 의존성을 걷어내고 성능을 잡다: ACE-Step 1.5의 커스텀 vLLM 엔진 도입기
- 현재글 : [triton] GFX1250에서 AsyncCopy의 OOB Shared Memory 주소를 이용한 마스킹
- 다음글 [Loki] Shard Factor 1일 때 Shuffle Shard 생략으로 메모리 50% 절감
댓글