[triton] Triton 2CTA Block-Scaled Matmul — cuBLAS 대비 성능 비교
PR 링크: triton-lang/triton#9697 상태: Merged | 변경: +990 / -0
들어가며
NVIDIA Blackwell GPU에서 block-scaled matrix multiplication은 FP4/FP8 학습과 추론의 핵심 연산이다. 이 PR은 Triton의 experimental Gluon API를 사용하여 2CTA(Cooperative Thread Array) warp-specialized block-scaled matmul을 구현한다. 두 CTA가 하나의 output tile을 협력 처리하여 SMEM 사용량을 줄이고 arithmetic intensity를 높이며, mxfp8, mxfp4, nvfp4 세 가지 포맷을 모두 지원한다.
핵심 코드 분석
1. 2CTA 협력 구조와 CGA Layout
2CTA 모드의 핵심은 CGA(Cooperative Group Array) layout 설정이다.
def mma_scaled_get_configs(pre_hook=None, cga_layouts=None):
if cga_layouts is None:
cga_layouts = [(), ((1, 0), )]
return [
triton.Config(
{
"BLOCK_M": BM, "BLOCK_N": BN, "BLOCK_K": BK,
"CGA_LAYOUT": cga_layout,
},
num_warps=4,
num_ctas=2**len(cga_layout),
)
for BM in (128, 256)
...
for cga_layout in cga_layouts
if BM // (2**len(cga_layout)) == 128
]
CGA_LAYOUT=((1, 0),)은 두 CTA가 M 차원을 분할함을 의미한다. tcgen05_mma_scaled는 CTA당 BLOCK_M=128을 요구하므로, 2CTA 모드에서는 BLOCK_M=256이 된다. B operand는 두 CTA가 공유하여 TMA 로드를 절반으로 줄인다.
2. Shared Operand 전략 — B Scale Multicast
tma.async_copy_global_to_shared(
b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], bar,
b_scale_bufs.index(index), pred, multicast=multicast_b_scale
)
B 행렬과 B scale은 두 CTA가 동일한 데이터를 필요로 하므로 TMA multicast로 한 번만 로드한다. CGA layout에서 B scale의 split 차원은 [[0, 0, 0, 0, 0]] (분할 없음)으로 설정된다.
3. Online MMA with Scale Unswizzle
Scale factor는 TMA의 packed block format으로 로드되므로, SMEM에서 unswizzle이 필요하다.
@gluon.jit
def unswizzle_scales_shared_memory(smem, BLOCK_MN, BLOCK_K, VEC_SIZE):
smem = smem.reshape((smem.shape[1], smem.shape[2], 32, 4, 4))
smem = smem.permute((0, 3, 2, 1, 4))
return smem.reshape((BLOCK_MN, BLOCK_K // VEC_SIZE))
Unswizzle 후 scale을 tensor memory에 복사하고, tcgen05_mma_scaled로 scaled MMA를 실행한다.
tcgen05_mma_scaled(
a_smem, b_smem.permute((1, 0)), acc_tmem,
a_scale_tmem, b_scale_tmem,
a_format, b_format,
use_acc=use_acc, pred=pred
)
4. Planar Snake Tile Scheduling
타일 스케줄링은 L2 cache 지역성을 높이는 planar snake 패턴을 사용한다.
@gluon.jit
def _planar_snake(lin_idx, m_tiles, n_tiles, minor_dim, tile_width):
minor_tile_idx = lin_idx // (tile_width * major_size)
full_major = gl.where(
(minor_tile_idx % 2) == 0,
full_major_within,
major_size - 1 - full_major_within
)
짝수 minor tile에서는 정방향, 홀수에서는 역방향으로 major 차원을 순회하여 인접 타일 간 데이터 재사용을 극대화한다.
왜 이게 좋은가
- 2CTA 협력: 단일 CTA 대비 B operand 로드를 절반으로 줄여 메모리 bandwidth 효율을 높인다.
- 3가지 포맷 지원: mxfp8 (
e4m3), mxfp4 (e2m1), nvfp4를 하나의 커널에서 처리한다.VEC_SIZE가 포맷에 따라 16 또는 32로 자동 설정된다. - Autotuning:
BLOCK_M/N/K,num_buffers,epilogue_n,grid_tile_width등 다양한 파라미터 조합을 자동 탐색한다. - cuBLAS와의 직접 비교를 포함하여 Triton Gluon의 경쟁력을 검증하는 레퍼런스 구현이다.
정리
- Blackwell의 tcgen05 MMA, TMA multicast, tensor memory 등 최신 하드웨어 기능을 Triton에서 활용하는 방법을 보여준다.
- Warp specialization (load partition / MMA partition 분리)은 high-throughput 커널 설계의 핵심 패턴이다.
- Scale factor의 packed block format과 unswizzle 과정은 FP4/FP8 커널 작성 시 반드시 이해해야 하는 부분이다.
참고 자료
- Triton Gluon 공식 예제 — Gluon API 사용 예제
- NVIDIA Blackwell Architecture Whitepaper — Blackwell GPU 아키텍처 소개
- MX (Microscaling) Data Formats Specification — mxfp4/mxfp8 포맷 명세
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Ultralytics] 캘리브레이션 데이터셋이 배치보다 작을 때 에러 대신 자동 조정
- 현재글 : [triton] Triton 2CTA Block-Scaled Matmul — cuBLAS 대비 성능 비교
- 다음글 [Ray] 메모리 압력 테스트의 로그 패턴 업데이트로 테스트 안정성 확보
댓글