본문으로 건너뛰기

[triton] Triton Gluon을 활용한 Blackwell 아키텍처에서의 Multi-CTA 행렬 곱셈 최적화

PR 링크: triton-lang/triton#9546 상태: Merged | 변경: +933 / -17

들어가며

NVIDIA의 최신 Blackwell 아키텍처는 이전 세대보다 훨씬 강력한 연산 성능을 제공하지만, 이를 온전히 활용하기 위해서는 하드웨어의 병렬 처리 구조에 최적화된 커널 작성이 필수적입니다. 특히 Triton의 실험적 기능인 Gluon을 사용하면 저수준의 하드웨어 제어(TMA, mbarrier 등)를 더 세밀하게 다룰 수 있습니다. 이번 PR은 Blackwell 환경에서 Two CTA를 활용한 행렬 곱셈(MatMul)을 구현하며, 특히 Cluster Launch Control(CLC) 결과 버퍼의 메모리 레이아웃과 검증 로직을 개선하여 다중 CTA 간의 데이터 정합성을 확보하는 데 중점을 둡니다.

코드 분석

1. lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp: CLC 버퍼 검증 로직 강화

기존에는 CLC 결과 버퍼의 크기가 단순히 [2]로 고정되어 있었으나, 다중 CTA 환경에서는 CTA 개수에 따라 버퍼 크기가 가변적이어야 합니다. 이를 위해 verifyCLCResultMemdesc 함수가 수정되었습니다.

Before:

if (desc.getShape().size() != 1 || desc.getShape()[0] != 2) {
  return emitError(loc)
         << "Expected CLC result buffer to have shape [2], but got ["
         << triton::join(desc.getShape(), ", ") << "]";
}

After:

auto rank = desc.getRank();
if (rank != 1 || desc.getDimSize(0) != 2 * numCTAs) {
  return emitError(loc) << "Expected CLC result buffer to have rank 1 and a "
                           "single dimension equal to 2x the number of CTAs, "
                           "but got "
                        << desc.getShape() << ".";
}

이제 numCTAs를 동적으로 계산하여, 각 CTA가 자신의 결과 버퍼에 안전하게 접근할 수 있도록 검증 로직이 개선되었습니다.

2. python/examples/gluon/03-matmul-multicta.py: Multi-CTA 행렬 곱셈 구현

새롭게 추가된 예제 코드는 Two CTA 구성을 지원하며, matmul_get_configs를 통해 two_cta 플래그에 따라 num_ctas를 동적으로 할당합니다.

num_ctas=2 if two_cta else 1,
# ...
if two_cta:
    cga_layouts = [[[1, 0]], [[0, 1]], [[1, 0]]]
    for desc, cga_layout in zip(("a_desc", "b_desc", "c_desc"), cga_layouts):
        nargs[desc].layout = gl.NVMMASharedLayout.get_default_for(
            nargs[desc].block_shape,
            as_gl_dtype(nargs[desc].base.dtype),
            cga_layout=cga_layout,
        )

왜 이게 좋은가

이번 최적화의 핵심은 **'데이터의 로컬리티(Locality)'**와 **'하드웨어 자원 활용의 극대화'**입니다.

  1. 성능 향상: 벤치마크 결과, Multi-CTA 구성을 통해 단일 CTA 대비 연산 처리량이 비약적으로 상승했습니다. Blackwell의 SM(Streaming Multiprocessor) 자원을 두 개의 CTA가 협력하여 사용함으로써, 메모리 대역폭과 연산 유닛의 유휴 시간을 최소화했습니다.
  2. 유연성: 리뷰 과정에서 논의된 것처럼, multicast=True를 사용할 때 각 CTA가 자신의 로컬 버퍼를 가지도록 설계함으로써, 데이터 브로드캐스팅 시 발생할 수 있는 경합(Contention)을 방지했습니다.
  3. 교훈: GPU 프로그래밍에서 다중 CTA를 사용할 때는 단순히 연산량을 늘리는 것이 아니라, 각 CTA가 독립적인 메모리 영역을 효율적으로 점유하도록 레이아웃을 설계하는 것이 성능의 핵심입니다.

리뷰어 피드백 분석

리뷰어 peterbell10lezcano 사이의 논의는 매우 흥미롭습니다. 초기에는 CLC 버퍼를 모든 CTA가 공유하는 단일 버퍼로 생각했으나, multicast 환경에서는 각 CTA가 자신의 로컬 복사본을 가지는 것이 정합성 측면에서 안전하다는 결론에 도달했습니다. 이는 분산 시스템 설계의 원칙이 GPU 커널 레벨에서도 동일하게 적용됨을 보여줍니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글