[triton] Triton Gluon을 활용한 고성능 2CTA 블록 스케일 행렬 곱셈 최적화
PR 링크: triton-lang/triton#9697 상태: Merged | 변경: +990 / -0
들어가며
최신 GPU 아키텍처에서 행렬 곱셈(MatMul)의 성능을 극대화하기 위해서는 단순히 연산 유닛을 활용하는 것을 넘어, 메모리 계층 구조와 스레드 블록 간의 협업을 정교하게 제어해야 합니다. 이번 Triton 레포지토리의 PR은 Gluon 프레임워크를 사용하여 2개의 CTA(Cooperative Thread Array)가 하나의 출력 타일을 공유하는 '2CTA 워프 전문화(Warp-specialized)' 기법을 도입했습니다. 이를 통해 연산 강도(Arithmetic Intensity)를 높이고, 각 CTA당 필요한 공유 메모리(SMEM) 사용량을 줄여 성능을 크게 향상시켰습니다.
코드 분석
1. 2CTA 협업 및 타일 스케줄링
기존 1CTA 방식과 달리, 2개의 CTA가 협업하기 위해 _planar_snake 스케줄러를 도입하여 타일 접근 순서를 최적화했습니다.
@gluon.jit
def _planar_snake(lin_idx, m_tiles, n_tiles, minor_dim: gl.constexpr, tile_width: gl.constexpr):
# ... (생략) ...
full_major = gl.where((minor_tile_idx % 2) == 0, full_major_within, major_size - 1 - full_major_within)
# ... (생략) ...
이 스케줄러는 캐시 지역성을 극대화하기 위해 지그재그(Snake) 패턴으로 타일을 순회하며, 여러 CTA가 데이터를 효율적으로 공유할 수 있도록 돕습니다.
2. 비동기 행렬 곱셈 구현 (Async MMA)
tcgen05_mma_scaled를 사용하여 Blackwell 아키텍처의 하드웨어 가속 기능을 활용합니다. 특히 스케일링 팩터를 처리하는 unswizzle_scales_shared_memory 함수는 메모리 레이아웃을 재구성하여 병목을 제거합니다.
@gluon.jit
def unswizzle_scales_shared_memory(smem, BLOCK_MN: gl.constexpr, BLOCK_K: gl.constexpr, VEC_SIZE: gl.constexpr):
smem = smem.reshape((smem.shape[1], smem.shape[2], 32, 4, 4))
smem = smem.permute((0, 3, 2, 1, 4))
return smem.reshape((BLOCK_MN, BLOCK_K // VEC_SIZE))
왜 이게 좋은가
성능 개선 수치
이번 최적화를 통해 mxfp8-mxfp8 연산에서 1CTA 대비 최대 15%의 성능 향상을 기록했습니다.
| MNK | 1cta (TFLOPS) | 2cta (TFLOPS) | 2cta/1cta |
|---|---|---|---|
| 8192 | 2525.9 | 2895.0 | 1.15 |
| 16384 | 2409.3 | 2755.0 | 1.14 |
일반적 교훈
- Workload Partitioning: 여러 CTA가 하나의 타일을 공유하게 함으로써 SMEM 압박을 분산하고, 더 큰 블록 사이즈를 효율적으로 처리할 수 있습니다.
- Memory Swizzling: 하드웨어 가속기(Tensor Core)가 선호하는 메모리 레이아웃으로 데이터를 미리 정렬(Unswizzle)하는 과정은 메모리 대역폭 병목을 해결하는 핵심입니다.
- Autotuning: 리뷰 과정에서 언급되었듯,
EPILOGUE_BLOCK_N이나GRID_TILE_WIDTH와 같은 하이퍼파라미터를 자동 튜닝(Autotuning)하여 아키텍처별 최적의 설정을 찾는 것이 필수적입니다.
리뷰어 피드백 반영
리뷰어 lezcano는 allocate_mbarrier 사용을 통한 코드 간결화와 gl.num_ctas()를 활용한 하드코딩 제거를 제안했습니다. 또한 do_bench를 사용하여 더 현실적인 벤치마크 환경을 구성할 것을 권장했으며, 이는 실제 프로덕션 수준의 성능 측정에 중요한 가이드라인이 되었습니다.
참고 자료
- https://triton-lang.org/main/index.html
- https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-copy-instructions
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [PaddleOCR] PaddleOCR-VL 배포 문서 개선 — Docker 이미지 및 디바이스 호환성 가이드 추가
- 현재글 : [triton] Triton Gluon을 활용한 고성능 2CTA 블록 스케일 행렬 곱셈 최적화
- 다음글 [triton] AMD Pipelined Loop에서 TDM Load의 Buffer Race 수정
댓글