[Triton] 2CTA Block Scale MMA with tcgen05.cp — 두 CTA 협력 행렬 곱셈
PR 링크: triton-lang/triton#9460 상태: Merged | 변경: +343 / -84
들어가며
NVIDIA Blackwell 아키텍처의 tcgen05(5세대 Tensor Core Generator)는 두 CTA가 협력하여 행렬 곱셈을 수행하는 2-CTA 모드를 지원한다. Block Scale MMA는 행렬의 블록 단위로 스케일링을 적용하는 저정밀 행렬 곱셈이다.
이 PR은 2-CTA Block Scale MMA의 전체 데이터 경로(TMA load → tcgen05.cp → MMA → commit)를 구현하여, 두 CTA가 스케일 데이터를 공유하면서 행렬 곱셈을 수행할 수 있게 한다.
핵심 코드 분석
전체 파이프라인 흐름
CTA 0: TMA load(A_tile) → shared memory → tcgen05.cp → TMEM → MMA
CTA 1: TMA load(B_tile) → shared memory → tcgen05.cp → TMEM → MMA
↑ scale data를 cross-CTA shared memory로 공유
tcgen05.cp를 통한 scale 데이터 복사
// tcgen05.cp: shared memory → TMEM (Tensor Memory) 복사
// 2-CTA 모드에서는 다른 CTA의 shared memory에서도 읽기 가능
void lowerTCGen05CpOp(TCGen05CpOp op) {
// cp.async.bulk.tensor.shared::cta.tmem [dst], [src]
// 2-CTA 모드: src가 partner CTA의 shared memory일 수 있음
}
TMA barrier 수정
// Before: 단일 CTA의 TMA 완료만 대기
mbarrier.expect(bar, tma_bytes);
mbarrier.wait(bar, phase);
// After: 2-CTA에서 scale data의 cross-CTA 전달도 포함
// CTA 0이 scale을 로드하고 CTA 1에서 사용하려면
// barrier가 양쪽 CTA의 완료를 보장해야 함
mbarrier.expect(bar, tma_bytes + scale_bytes);
// TMA load 완료 + partner CTA의 scale 복사 완료 후 진행
Bug Fix: barrier 카운트 수정
// Before: scale의 TMA 바이트 수가 barrier expect에 미포함
auto expectedBytes = computeTMABytes(operandA) +
computeTMABytes(operandB);
// After: scale 데이터도 포함
auto expectedBytes = computeTMABytes(operandA) +
computeTMABytes(operandB) +
computeTMABytes(scaleA) +
computeTMABytes(scaleB);
왜 이게 좋은가
- 2-CTA 활용: 두 CTA의 리소스를 합쳐 더 큰 타일을 처리하여 shared memory와 레지스터 활용도를 높인다.
- Block Scale 지원: 저정밀 행렬 곱셈에 블록 단위 스케일을 적용하여, 정확도와 성능의 균형을 맞춘다.
- barrier 정합성: scale 데이터의 TMA 전송도 barrier에 올바르게 포함시켜 데이터 정합성을 보장한다.
정리
이 PR은 Blackwell의 2-CTA Block Scale MMA를 tcgen05.cp 명령으로 구현한다. TMA → cp → MMA → commit의 전체 파이프라인과 cross-CTA scale 데이터 공유를 포함하며, barrier 카운트 버그도 수정한다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 핵심 코드와 explaination은 실제 PR diff를 기반으로 합니다.
관련 포스트
- [triton] Consumer Blackwell(sm_120)에서 PTX Codegen Segfault 수정
- [triton] Multi-CTA 튜토리얼 추가: CGA 기반 협력 연산
- [triton] Gluon tmem_load에서 Register Layout 자동 추론
- [Triton] Blackwell 2D activation-scale layout에서 ragged metadata 없이 동작하도록 수정
- [triton] Blackwell GPU Cluster Launch Control 지원으로 Persistent Kernel 워크로드 밸런싱 구현
PR Analysis 의 다른글
- 이전글 [triton] 캐시 테스트를 Device Agnostic하게 개선
- 현재글 : [Triton] 2CTA Block Scale MMA with tcgen05.cp — 두 CTA 협력 행렬 곱셈
- 다음글 [Ray] Dashboard 죽은 노드 캐시의 변수 섀도잉 버그 수정
댓글