본문으로 건너뛰기

[triton] GFX1250에서 AsyncCopy의 OOB Shared Memory 주소를 이용한 마스킹

PR 링크: triton-lang/triton#9761 상태: Merged | 변경: +62 / -5

들어가며

Async copy 연산에서 일부 스레드의 로드를 조건부로 건너뛰어야 할 때, 전통적인 방법은 브랜치(분기)를 사용하는 것입니다. 그러나 GPU에서 브랜치는 warp divergence를 유발하여 성능에 부정적인 영향을 미칩니다. GFX1250은 per-lane LDS 주소를 지원하는데, 이 PR은 이 하드웨어 특성을 활용하여 마스킹된 스레드에 out-of-range 주소를 설정하는 방식으로 브랜치 없이 마스킹을 구현합니다.

핵심 코드 분석

Before (모든 아키텍처에서 브랜치 사용):

auto cond = b.and_(threadPred, maybeSwizzledMaskElem);
auto [loadBlock, afterLoadBlock] = emitBranch(rewriter, loc, cond);

emitAsyncLoad(rewriter, loc, targetInfo, vecBits, srcElem, shmemAddr,
              op.getCache(), multicastMask);

rewriter.setInsertionPointToStart(afterLoadBlock);

After (GFX1250에서 OOB 주소로 마스킹):

if (targetInfo.supportsDirectToLDSScattering()) {
    // HW will drop the load before fetching from global memory
    Value outOfRangeAddress =
        b.inttoptr(shmemAddr.getType(), b.i32_val(0x7FFFFFFF));
    Value predicatedAddress = b.select(cond, shmemAddr, outOfRangeAddress);

    emitAsyncLoad(rewriter, loc, targetInfo, vecBits, srcElem,
                  predicatedAddress, op.getCache(), multicastMask);
} else {
    // 기존 브랜치 방식 유지
    auto [loadBlock, afterLoadBlock] = emitBranch(rewriter, loc, cond);
    emitAsyncLoad(...);
    rewriter.setInsertionPointToStart(afterLoadBlock);
}

왜 이게 좋은가

0x7FFFFFFF라는 out-of-range LDS 주소를 사용하면 하드웨어가 global memory fetch 자체를 수행하지 않고 해당 lane의 로드를 drop합니다. 이는 브랜치 없이 마스킹을 달성하여 다음과 같은 이점을 제공합니다: (1) warp divergence 제거로 인한 실행 효율 향상, (2) 제어 흐름 그래프 단순화로 컴파일러 최적화 기회 증가, (3) 명령어 수 감소로 instruction cache 효율 개선. 기존 아키텍처에서는 여전히 브랜치 방식을 사용하여 호환성을 유지합니다.

정리

GFX1250의 per-lane LDS 주소 기능을 활용하여 async copy 마스킹을 브랜치 대신 out-of-range 주소 방식으로 구현하고, 이전 아키텍처에서는 기존 방식을 유지하도록 분기했습니다.

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글