[Triton] Gluon에서 2CTA 모드 초기 지원
들어가며
NVIDIA Blackwell GPU의 tcgen05 명령어는 2CTA 모드를 지원한다. 두 CTA가 하나의 MMA 연산을 공동으로 수행하여 더 큰 타일을 처리할 수 있다. 이 PR은 Gluon에서 layout을 분석하여 자동으로 2CTA 모드를 감지하고 디스패치하는 기능을 추가한다.
핵심 코드 분석
TensorMemoryEncoding 확장
// Before
DefaultValuedParameter<"unsigned", "1">:$CTASplitN
// After
DefaultValuedParameter<"unsigned", "1">:$CTASplitN,
DefaultValuedParameter<"bool", "false">:$twoCTAs
twoCTAs 파라미터가 추가되어 2CTA 모드의 tensor memory layout을 표현할 수 있게 되었다.
M=64 2CTA 레이아웃 수정
// 2CTA mode에서 M=64일 때의 특수 레이아웃 처리
bool isM64TwoCTA = blockM == 64 && encoding.getTwoCTAs();
if (isM64TwoCTA) {
blockM *= 2; // 128xblockN의 transpose로 처리
splitM /= 2;
}
// ...
if (isM64TwoCTA) {
// row와 col의 마지막 basis를 swap
std::swap(rowBases[rowBases.size() - 1], colBases[colBases.size() - 1]);
}
PTX 문서에 따라 M=64 2CTA는 128xblockN 레이아웃의 transpose로 구현된다.
AccelerateMatmul에서 자동 감지
Attribute accEncoding = TensorMemoryEncodingAttr::get(
context, instrShape[0], instrShape[1], colStride,
CTASplitNum[0], CTASplitNum[1], useTwoCTAs); // twoCTAs 전달
왜 이게 좋은가
- 자동 디스패치: 사용자가 layout만 지정하면 2CTA 모드 여부를 자동 판단한다.
- 정확한 메모리 레이아웃: M=64 케이스의 basis swap을 PTX 문서 기반으로 정확히 구현했다.
- 검증 강화: nCol 크기 검증이 추가되어 유효하지 않은 layout을 조기에 차단한다.
정리
+279/-172 변경으로, Blackwell 2CTA 모드의 IR 표현, LinearLayout 변환, 자동 감지 로직을 모두 포함한다. 후속 PR에서 임의 CTA 수로의 확장이 예정되어 있다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글