[Triton] M=64 2CTA 모드 지원 추가
PR 링크: triton-lang/triton#8922 상태: Merged | 변경: +87 / -31
들어가며
NVIDIA Blackwell GPU에서 tcgen05.mma 명령어는 2CTA(Cooperative Thread Array) 모드로 실행될 수 있다. 기존에는 M=128 instruction shape만 2CTA를 지원했지만, 이 PR은 M=64에서도 2CTA 모드를 사용할 수 있도록 LinearLayout 변환 로직을 확장한다.
2CTA 모드에서는 두 개의 CTA가 협력하여 하나의 큰 행렬 연산을 수행한다. M=64의 경우 CTA 분할 방식이 M=128과 다르기 때문에, basis 스왑 로직을 새로 구현해야 했다.
핵심 코드 분석
Before (LinearLayoutConversions.cpp)
if (isM64TwoCTA) {
auto bases = ret.getBases();
std::swap(bases[kRow].back(), bases[kCol].back());
ret = LinearLayout(std::move(bases), ret.getOutDims(), ret.isSurjective());
}
단순히 마지막 basis끼리 swap하는 방식이었다.
After (LinearLayoutConversions.cpp)
if (isM64TwoCTA) {
auto bases = ret.getBases();
auto basisCTA1 =
llvm::Log2_32(encoding.getBlockN() * encoding.getColStride()) - 1;
std::swap(bases[kRow].back(), bases[kCol][basisCTA1]);
ret = LinearLayout(std::move(bases), ret.getOutDims(), ret.isSurjective());
}
blockN과 colStride를 기반으로 정확한 CTA1 basis 인덱스를 계산하여 swap한다. 이로써 M=64에서도 올바른 데이터 분배가 이루어진다.
왜 이게 좋은가
- 유연성 확대: M=64 instruction shape에서도 2CTA를 사용할 수 있어 더 다양한 커널 구성이 가능
- 정확한 레이아웃 계산:
blockN * colStride를 기반으로 한 basis 인덱스 계산으로 데이터 일관성 보장 - Shared memory 관리 개선: 테스트에서
max_shared_mem검사를 추가하여 OOM 방지
정리
2CTA 모드는 GPU의 병렬성을 극대화하는 핵심 기능이다. M=64 지원 추가로 Blackwell 아키텍처의 성능을 더 폭넓게 활용할 수 있게 되었다. LinearLayout의 basis 변환이라는 추상적인 개념이 실제 하드웨어 동작과 어떻게 매핑되는지 잘 보여주는 PR이다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] [Blackwell] NVIDIA 차세대 아키텍처를 위한 Triton의 tcgen05.ld.red 최적화 분석
- 현재글 : [Triton] M=64 2CTA 모드 지원 추가
- 다음글 [llm-compressor] Memoryless Observers - 메모리 효율적 가중치 관찰자
댓글