본문으로 건너뛰기

[Triton] M=64 2CTA 모드 지원 추가

PR 링크: triton-lang/triton#8922 상태: Merged | 변경: +87 / -31

들어가며

NVIDIA Blackwell GPU에서 tcgen05.mma 명령어는 2CTA(Cooperative Thread Array) 모드로 실행될 수 있다. 기존에는 M=128 instruction shape만 2CTA를 지원했지만, 이 PR은 M=64에서도 2CTA 모드를 사용할 수 있도록 LinearLayout 변환 로직을 확장한다.

2CTA 모드에서는 두 개의 CTA가 협력하여 하나의 큰 행렬 연산을 수행한다. M=64의 경우 CTA 분할 방식이 M=128과 다르기 때문에, basis 스왑 로직을 새로 구현해야 했다.

핵심 코드 분석

Before (LinearLayoutConversions.cpp)

if (isM64TwoCTA) {
  auto bases = ret.getBases();
  std::swap(bases[kRow].back(), bases[kCol].back());
  ret = LinearLayout(std::move(bases), ret.getOutDims(), ret.isSurjective());
}

단순히 마지막 basis끼리 swap하는 방식이었다.

After (LinearLayoutConversions.cpp)

if (isM64TwoCTA) {
  auto bases = ret.getBases();
  auto basisCTA1 =
      llvm::Log2_32(encoding.getBlockN() * encoding.getColStride()) - 1;
  std::swap(bases[kRow].back(), bases[kCol][basisCTA1]);
  ret = LinearLayout(std::move(bases), ret.getOutDims(), ret.isSurjective());
}

blockNcolStride를 기반으로 정확한 CTA1 basis 인덱스를 계산하여 swap한다. 이로써 M=64에서도 올바른 데이터 분배가 이루어진다.

왜 이게 좋은가

  1. 유연성 확대: M=64 instruction shape에서도 2CTA를 사용할 수 있어 더 다양한 커널 구성이 가능
  2. 정확한 레이아웃 계산: blockN * colStride를 기반으로 한 basis 인덱스 계산으로 데이터 일관성 보장
  3. Shared memory 관리 개선: 테스트에서 max_shared_mem 검사를 추가하여 OOM 방지

정리

2CTA 모드는 GPU의 병렬성을 극대화하는 핵심 기능이다. M=64 지원 추가로 Blackwell 아키텍처의 성능을 더 폭넓게 활용할 수 있게 되었다. LinearLayout의 basis 변환이라는 추상적인 개념이 실제 하드웨어 동작과 어떻게 매핑되는지 잘 보여주는 PR이다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글

관련 포스트

PR Analysis 의 다른글