[triton] Generic Multi-CTA convert_layout 지원
PR 링크: triton-lang/triton#9317 상태: Merged | 변경: +426 / -439
들어가며
Triton에서 convert_layout은 텐서의 레이아웃을 변환하는 핵심 연산입니다. 기존에는 multi-CTA(여러 CTA가 클러스터로 묶인 환경)에서 CTA 간 데이터 이동이 필요한 경우 별도로 처리했습니다. 이 PR은 warp-sync, CTA-sync, cluster-sync를 일관된 코드 경로로 통합합니다.
핵심 코드 분석
1. isCvtDimSync 범용화
Before:
bool isCvtWarpSync(const triton::LinearLayout &srcLayout,
const triton::LinearLayout &dstLayout) {
auto comp = dstLayout.invertAndCompose(srcLayout);
auto kWarp = StringAttr::get(ctx, "warp");
return comp.isTrivialOver(kWarp) &&
srcLayout.getFreeVariableMasks()[kWarp] == 0 &&
dstLayout.getFreeVariableMasks()[kWarp] == 0;
}
After:
bool isCvtDimSync(const triton::LinearLayout &srcLayout,
const triton::LinearLayout &dstLayout, StringAttr dim) {
auto kWarp = StringAttr::get(ctx, "warp");
auto kBlock = StringAttr::get(ctx, "block");
assert((dim == kWarp || dim == kBlock));
auto parentTrivial = true;
if (dim == kWarp) {
parentTrivial = isCvtDimSync(srcLayout, dstLayout, kBlock);
}
auto comp = dstLayout.invertAndCompose(srcLayout);
return parentTrivial && comp.isTrivialOver(dim) &&
srcLayout.getFreeVariableMasks()[dim] == 0 &&
dstLayout.getFreeVariableMasks()[dim] == 0;
}
기존 isCvtWarpSync를 isCvtDimSync로 일반화하여 warp뿐 아니라 block(CTA) 차원에서도 sync 필요 여부를 판단합니다. 재귀적으로 상위 차원(block)의 trivial 여부를 먼저 확인합니다.
2. clusterBarrier 인터페이스 추가
virtual void clusterBarrier(Location loc, RewriterBase &rewriter) const = 0;
TargetInfoBase에 cluster-level barrier 인터페이스를 추가하여, CTA 간 동기화가 필요한 convert_layout에서 사용합니다.
왜 이게 좋은가
- 코드 통합: warp/CTA/cluster 레벨 sync를 하나의 코드 경로로 처리하여 유지보수가 용이합니다.
- 정확한 동기화: 재귀적 dim-sync 검사로 불필요한 barrier를 줄이면서도 정확성을 보장합니다.
- 확장성: 새로운 차원이 추가되더라도
isCvtDimSync의 재귀 구조로 쉽게 확장 가능합니다.
정리
이 PR은 Triton의 layout conversion을 multi-CTA 환경에 맞게 범용화합니다. isCvtDimSync 함수의 재귀적 차원별 sync 검사와 clusterBarrier 인터페이스 도입이 핵심입니다.
참고 자료
이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Triton] TMA im2col 모드 — Gluon API 구현
- 현재글 : [triton] Generic Multi-CTA convert_layout 지원
- 다음글 [triton] 클러스터 환경을 위한 Membar 패스 확장
댓글