본문으로 건너뛰기

[triton] Triton 2CTA Block-Scaled Matmul — cuBLAS 대비 성능 비교

PR 링크: triton-lang/triton#9697 상태: Merged | 변경: +990 / -0

들어가며

NVIDIA Blackwell GPU에서 block-scaled matrix multiplication은 FP4/FP8 학습과 추론의 핵심 연산이다. 이 PR은 Triton의 experimental Gluon API를 사용하여 2CTA(Cooperative Thread Array) warp-specialized block-scaled matmul을 구현한다. 두 CTA가 하나의 output tile을 협력 처리하여 SMEM 사용량을 줄이고 arithmetic intensity를 높이며, mxfp8, mxfp4, nvfp4 세 가지 포맷을 모두 지원한다.

핵심 코드 분석

1. 2CTA 협력 구조와 CGA Layout

2CTA 모드의 핵심은 CGA(Cooperative Group Array) layout 설정이다.

def mma_scaled_get_configs(pre_hook=None, cga_layouts=None):
    if cga_layouts is None:
        cga_layouts = [(), ((1, 0), )]
    return [
        triton.Config(
            {
                "BLOCK_M": BM, "BLOCK_N": BN, "BLOCK_K": BK,
                "CGA_LAYOUT": cga_layout,
            },
            num_warps=4,
            num_ctas=2**len(cga_layout),
        )
        for BM in (128, 256)
        ...
        for cga_layout in cga_layouts
        if BM // (2**len(cga_layout)) == 128
    ]

CGA_LAYOUT=((1, 0),)은 두 CTA가 M 차원을 분할함을 의미한다. tcgen05_mma_scaled는 CTA당 BLOCK_M=128을 요구하므로, 2CTA 모드에서는 BLOCK_M=256이 된다. B operand는 두 CTA가 공유하여 TMA 로드를 절반으로 줄인다.

2. Shared Operand 전략 — B Scale Multicast

tma.async_copy_global_to_shared(
    b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], bar,
    b_scale_bufs.index(index), pred, multicast=multicast_b_scale
)

B 행렬과 B scale은 두 CTA가 동일한 데이터를 필요로 하므로 TMA multicast로 한 번만 로드한다. CGA layout에서 B scale의 split 차원은 [[0, 0, 0, 0, 0]] (분할 없음)으로 설정된다.

3. Online MMA with Scale Unswizzle

Scale factor는 TMA의 packed block format으로 로드되므로, SMEM에서 unswizzle이 필요하다.

@gluon.jit
def unswizzle_scales_shared_memory(smem, BLOCK_MN, BLOCK_K, VEC_SIZE):
    smem = smem.reshape((smem.shape[1], smem.shape[2], 32, 4, 4))
    smem = smem.permute((0, 3, 2, 1, 4))
    return smem.reshape((BLOCK_MN, BLOCK_K // VEC_SIZE))

Unswizzle 후 scale을 tensor memory에 복사하고, tcgen05_mma_scaled로 scaled MMA를 실행한다.

tcgen05_mma_scaled(
    a_smem, b_smem.permute((1, 0)), acc_tmem,
    a_scale_tmem, b_scale_tmem,
    a_format, b_format,
    use_acc=use_acc, pred=pred
)

4. Planar Snake Tile Scheduling

타일 스케줄링은 L2 cache 지역성을 높이는 planar snake 패턴을 사용한다.

@gluon.jit
def _planar_snake(lin_idx, m_tiles, n_tiles, minor_dim, tile_width):
    minor_tile_idx = lin_idx // (tile_width * major_size)
    full_major = gl.where(
        (minor_tile_idx % 2) == 0,
        full_major_within,
        major_size - 1 - full_major_within
    )

짝수 minor tile에서는 정방향, 홀수에서는 역방향으로 major 차원을 순회하여 인접 타일 간 데이터 재사용을 극대화한다.

왜 이게 좋은가

  • 2CTA 협력: 단일 CTA 대비 B operand 로드를 절반으로 줄여 메모리 bandwidth 효율을 높인다.
  • 3가지 포맷 지원: mxfp8 (e4m3), mxfp4 (e2m1), nvfp4를 하나의 커널에서 처리한다. VEC_SIZE가 포맷에 따라 16 또는 32로 자동 설정된다.
  • Autotuning: BLOCK_M/N/K, num_buffers, epilogue_n, grid_tile_width 등 다양한 파라미터 조합을 자동 탐색한다.
  • cuBLAS와의 직접 비교를 포함하여 Triton Gluon의 경쟁력을 검증하는 레퍼런스 구현이다.

정리

  • Blackwell의 tcgen05 MMA, TMA multicast, tensor memory 등 최신 하드웨어 기능을 Triton에서 활용하는 방법을 보여준다.
  • Warp specialization (load partition / MMA partition 분리)은 high-throughput 커널 설계의 핵심 패턴이다.
  • Scale factor의 packed block format과 unswizzle 과정은 FP4/FP8 커널 작성 시 반드시 이해해야 하는 부분이다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글