본문으로 건너뛰기

[Triton] 2CTA Block Scale MMA with tcgen05.cp — 두 CTA 협력 행렬 곱셈

PR 링크: triton-lang/triton#9460 상태: Merged | 변경: +343 / -84

들어가며

NVIDIA Blackwell 아키텍처의 tcgen05(5세대 Tensor Core Generator)는 두 CTA가 협력하여 행렬 곱셈을 수행하는 2-CTA 모드를 지원한다. Block Scale MMA는 행렬의 블록 단위로 스케일링을 적용하는 저정밀 행렬 곱셈이다.

이 PR은 2-CTA Block Scale MMA의 전체 데이터 경로(TMA load → tcgen05.cp → MMA → commit)를 구현하여, 두 CTA가 스케일 데이터를 공유하면서 행렬 곱셈을 수행할 수 있게 한다.

핵심 코드 분석

전체 파이프라인 흐름

CTA 0: TMA load(A_tile) → shared memory → tcgen05.cp → TMEM → MMA
CTA 1: TMA load(B_tile) → shared memory → tcgen05.cp → TMEM → MMA
                     ↑ scale data를 cross-CTA shared memory로 공유

tcgen05.cp를 통한 scale 데이터 복사

// tcgen05.cp: shared memory → TMEM (Tensor Memory) 복사
// 2-CTA 모드에서는 다른 CTA의 shared memory에서도 읽기 가능
void lowerTCGen05CpOp(TCGen05CpOp op) {
  // cp.async.bulk.tensor.shared::cta.tmem [dst], [src]
  // 2-CTA 모드: src가 partner CTA의 shared memory일 수 있음
}

TMA barrier 수정

// Before: 단일 CTA의 TMA 완료만 대기
mbarrier.expect(bar, tma_bytes);
mbarrier.wait(bar, phase);

// After: 2-CTA에서 scale data의 cross-CTA 전달도 포함
// CTA 0이 scale을 로드하고 CTA 1에서 사용하려면
// barrier가 양쪽 CTA의 완료를 보장해야 함
mbarrier.expect(bar, tma_bytes + scale_bytes);
// TMA load 완료 + partner CTA의 scale 복사 완료 후 진행

Bug Fix: barrier 카운트 수정

// Before: scale의 TMA 바이트 수가 barrier expect에 미포함
auto expectedBytes = computeTMABytes(operandA) +
                     computeTMABytes(operandB);

// After: scale 데이터도 포함
auto expectedBytes = computeTMABytes(operandA) +
                     computeTMABytes(operandB) +
                     computeTMABytes(scaleA) +
                     computeTMABytes(scaleB);

왜 이게 좋은가

  1. 2-CTA 활용: 두 CTA의 리소스를 합쳐 더 큰 타일을 처리하여 shared memory와 레지스터 활용도를 높인다.
  2. Block Scale 지원: 저정밀 행렬 곱셈에 블록 단위 스케일을 적용하여, 정확도와 성능의 균형을 맞춘다.
  3. barrier 정합성: scale 데이터의 TMA 전송도 barrier에 올바르게 포함시켜 데이터 정합성을 보장한다.

정리

이 PR은 Blackwell의 2-CTA Block Scale MMA를 tcgen05.cp 명령으로 구현한다. TMA → cp → MMA → commit의 전체 파이프라인과 cross-CTA scale 데이터 공유를 포함하며, barrier 카운트 버그도 수정한다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 핵심 코드와 explaination은 실제 PR diff를 기반으로 합니다.

댓글

관련 포스트

PR Analysis 의 다른글