[triton] Membar 분석 함수 호출 시 smem offset 수정
PR 링크: triton-lang/triton#9327 상태: Merged | 변경: +85 / -4
들어가며
Triton에서 함수 호출이 있을 때, callee 함수의 shared memory 접근 패턴을 caller 관점으로 변환해야 합니다. 기존에는 callee의 BlockInfo를 그대로 사용했는데, caller에서 다른 allocation이 앞에 있으면 offset이 맞지 않아 barrier가 누락될 수 있었습니다.
핵심 코드 분석
1. translateBlockInfoToCallsite 함수 추가
inline BlockInfo translateBlockInfoToCallsite(const BlockInfo &calleeBlockInfo,
size_t callOffset) {
BlockInfo translatedBlockInfo;
auto translateSlices = [&](const BlockInfo::SliceMapT &srcSlices,
BlockInfo::SliceMapT &dstSlices) {
for (const auto &[slice, ops] : srcSlices) {
auto translatedSlice =
slice.translated(callOffset, /*invalidateBufferId=*/true);
auto &dstOps = dstSlices[translatedSlice];
dstOps.insert(ops.begin(), ops.end());
}
};
translateSlices(calleeBlockInfo.syncReadSlices,
translatedBlockInfo.syncReadSlices);
translateSlices(calleeBlockInfo.syncWriteSlices,
translatedBlockInfo.syncWriteSlices);
return translatedBlockInfo;
}
2. Membar update에서 offset 적용
Before:
if (auto callee = dyn_cast<FunctionOpInterface>(
callOpInterface.resolveCallable()))
curBlockInfo = funcBlockInfoMap->lookup(callee);
After:
if (auto callee = dyn_cast<FunctionOpInterface>(
callOpInterface.resolveCallable())) {
auto calleeBlockInfo = funcBlockInfoMap->lookup(callee);
auto callBufferId = allocation->getBufferId(op);
size_t callOffset = 0;
if (callBufferId != Allocation::InvalidBufferId)
callOffset = allocation->getAllocatedInterval(callBufferId).start();
curBlockInfo = translateBlockInfoToCallsite(calleeBlockInfo, callOffset);
}
callee의 virtual buffer offset을 구한 뒤, callee의 모든 slice를 해당 offset만큼 이동시켜 caller의 물리적 주소 공간에 맞춥니다.
왜 이게 좋은가
- 정확한 aliasing 감지: Caller의 다른 allocation과 callee의 scratch buffer가 물리적으로 겹치는 경우를 정확히 감지합니다.
- 기존 분석 활용:
AllocationSlice::translated()메서드로 기존 slice 기반 분석 인프라를 재활용합니다.
정리
함수 호출 경계에서 shared memory offset이 올바르게 변환되지 않아 barrier가 누락되는 버그를 수정한 PR입니다. callee의 BlockInfo를 caller의 allocation offset만큼 평행이동시키는 것이 핵심입니다.
참고 자료
이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] 클러스터 환경을 위한 Membar 패스 확장
- 현재글 : [triton] Membar 분석 함수 호출 시 smem offset 수정
- 다음글 [Ray Serve] ClusterNodeInfoCache 정렬 버그 수정 및 중복 GCS RPC 제거로 캐시 갱신 최적화
댓글