본문으로 건너뛰기

[Triton] Gluon에서 2CTA 모드 초기 지원

들어가며

NVIDIA Blackwell GPU의 tcgen05 명령어는 2CTA 모드를 지원한다. 두 CTA가 하나의 MMA 연산을 공동으로 수행하여 더 큰 타일을 처리할 수 있다. 이 PR은 Gluon에서 layout을 분석하여 자동으로 2CTA 모드를 감지하고 디스패치하는 기능을 추가한다.

핵심 코드 분석

TensorMemoryEncoding 확장

// Before
DefaultValuedParameter<"unsigned", "1">:$CTASplitN

// After
DefaultValuedParameter<"unsigned", "1">:$CTASplitN,
DefaultValuedParameter<"bool", "false">:$twoCTAs

twoCTAs 파라미터가 추가되어 2CTA 모드의 tensor memory layout을 표현할 수 있게 되었다.

M=64 2CTA 레이아웃 수정

// 2CTA mode에서 M=64일 때의 특수 레이아웃 처리
bool isM64TwoCTA = blockM == 64 && encoding.getTwoCTAs();
if (isM64TwoCTA) {
    blockM *= 2;  // 128xblockN의 transpose로 처리
    splitM /= 2;
}
// ...
if (isM64TwoCTA) {
    // row와 col의 마지막 basis를 swap
    std::swap(rowBases[rowBases.size() - 1], colBases[colBases.size() - 1]);
}

PTX 문서에 따라 M=64 2CTA는 128xblockN 레이아웃의 transpose로 구현된다.

AccelerateMatmul에서 자동 감지

Attribute accEncoding = TensorMemoryEncodingAttr::get(
    context, instrShape[0], instrShape[1], colStride,
    CTASplitNum[0], CTASplitNum[1], useTwoCTAs);  // twoCTAs 전달

왜 이게 좋은가

  • 자동 디스패치: 사용자가 layout만 지정하면 2CTA 모드 여부를 자동 판단한다.
  • 정확한 메모리 레이아웃: M=64 케이스의 basis swap을 PTX 문서 기반으로 정확히 구현했다.
  • 검증 강화: nCol 크기 검증이 추가되어 유효하지 않은 layout을 조기에 차단한다.

정리

+279/-172 변경으로, Blackwell 2CTA 모드의 IR 표현, LinearLayout 변환, 자동 감지 로직을 모두 포함한다. 후속 PR에서 임의 CTA 수로의 확장이 예정되어 있다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글