본문으로 건너뛰기

[triton] Multi-CTA 튜토리얼 추가: CGA 기반 협력 연산

PR 링크: triton-lang/triton#9654 상태: Merged | 변경: +1173 / -10

들어가며

NVIDIA Hopper부터 도입된 CGA(Cooperative Grid Array)는 최대 16개의 CTA가 하나의 클러스터로 협력할 수 있게 합니다. 클러스터 내 CTA들은 TMA broadcasting, distributed shared memory 접근, mbarrier 기반 동기화 등의 기능을 활용합니다. 이 PR은 multi-CTA 프로그래밍의 개념과 실제 구현을 다루는 포괄적인 튜토리얼을 추가합니다.

핵심 코드 분석

Multi-CTA softmax 예제의 핵심 - 여러 CTA에 걸쳐 하나의 행을 처리:

@gluon.jit
def multicta_softmax_kernel(x_ptr, out_ptr, x_row_stride, out_row_stride,
                            BLOCK_N: gl.constexpr):
    pid = gl.program_id(0)
    cga_layout = ((1,), (2,), (4,), (8,), (16,))[
        :gl.num_ctas().bit_length() - 1]
    layout = gl.BlockedLayout([4], [32], [gl.num_warps()], [0],
                              cga_layout=cga_layout)
    offs_n = gl.arange(0, BLOCK_N, layout)
    # ...
    row_max = gl.max(x, axis=0)  # 자동 cross-CTA reduction

CGA layout은 linear layout 기반으로, 각 CTA가 텐서의 어느 부분을 담당하는지 정의합니다:

# 2 CTA, dim0 방향으로 분할
gl.BlockedLayout([1, 1], [1, 32], [1, 4], [1, 0],
                 cga_layout=[[1, 0]])

# 8 CTA, dim0 4개 + dim1 2개로 분할
gl.NVMMASharedLayout.get_default_for([M, N], gl.float16,
    cga_layout=[[1, 0], [2, 0], [0, 1]])

왜 이게 좋은가

이 튜토리얼은 multi-CTA 프로그래밍의 핵심 개념(CGA layout, distributed shared memory, TMA broadcasting, cross-CTA reduction)을 실제 동작하는 코드와 함께 설명합니다. 특히 softmax 예제는 단일 CTA로는 처리할 수 없는 매우 넓은 행(64K+ 요소)을 여러 CTA가 협력하여 처리하는 실용적인 use case를 보여줍니다. Blackwell에서 cublas를 넘어서는 성능을 달성하기 위한 multi-CTA GEMM의 기반이 됩니다.

정리

CGA 기반 multi-CTA 프로그래밍의 개념 설명, softmax 예제, 성능 비교를 포함하는 1100줄 규모의 포괄적 튜토리얼을 추가했습니다.

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글