본문으로 건너뛰기

[triton] Generic Multi-CTA convert_layout 지원

PR 링크: triton-lang/triton#9317 상태: Merged | 변경: +426 / -439

들어가며

Triton에서 convert_layout은 텐서의 레이아웃을 변환하는 핵심 연산입니다. 기존에는 multi-CTA(여러 CTA가 클러스터로 묶인 환경)에서 CTA 간 데이터 이동이 필요한 경우 별도로 처리했습니다. 이 PR은 warp-sync, CTA-sync, cluster-sync를 일관된 코드 경로로 통합합니다.

핵심 코드 분석

1. isCvtDimSync 범용화

Before:

bool isCvtWarpSync(const triton::LinearLayout &srcLayout,
                   const triton::LinearLayout &dstLayout) {
  auto comp = dstLayout.invertAndCompose(srcLayout);
  auto kWarp = StringAttr::get(ctx, "warp");
  return comp.isTrivialOver(kWarp) &&
         srcLayout.getFreeVariableMasks()[kWarp] == 0 &&
         dstLayout.getFreeVariableMasks()[kWarp] == 0;
}

After:

bool isCvtDimSync(const triton::LinearLayout &srcLayout,
                  const triton::LinearLayout &dstLayout, StringAttr dim) {
  auto kWarp = StringAttr::get(ctx, "warp");
  auto kBlock = StringAttr::get(ctx, "block");
  assert((dim == kWarp || dim == kBlock));
  auto parentTrivial = true;
  if (dim == kWarp) {
    parentTrivial = isCvtDimSync(srcLayout, dstLayout, kBlock);
  }
  auto comp = dstLayout.invertAndCompose(srcLayout);
  return parentTrivial && comp.isTrivialOver(dim) &&
         srcLayout.getFreeVariableMasks()[dim] == 0 &&
         dstLayout.getFreeVariableMasks()[dim] == 0;
}

기존 isCvtWarpSyncisCvtDimSync로 일반화하여 warp뿐 아니라 block(CTA) 차원에서도 sync 필요 여부를 판단합니다. 재귀적으로 상위 차원(block)의 trivial 여부를 먼저 확인합니다.

2. clusterBarrier 인터페이스 추가

virtual void clusterBarrier(Location loc, RewriterBase &rewriter) const = 0;

TargetInfoBase에 cluster-level barrier 인터페이스를 추가하여, CTA 간 동기화가 필요한 convert_layout에서 사용합니다.

왜 이게 좋은가

  • 코드 통합: warp/CTA/cluster 레벨 sync를 하나의 코드 경로로 처리하여 유지보수가 용이합니다.
  • 정확한 동기화: 재귀적 dim-sync 검사로 불필요한 barrier를 줄이면서도 정확성을 보장합니다.
  • 확장성: 새로운 차원이 추가되더라도 isCvtDimSync의 재귀 구조로 쉽게 확장 가능합니다.

정리

이 PR은 Triton의 layout conversion을 multi-CTA 환경에 맞게 범용화합니다. isCvtDimSync 함수의 재귀적 차원별 sync 검사와 clusterBarrier 인터페이스 도입이 핵심입니다.

참고 자료


이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글