[triton] Triton의 Ragged Matmul 메타데이터 계산 최적화: CPU 동기화 없는 효율적인 프로파일링
PR 링크: triton-lang/triton#10150 상태: Merged | 변경: +None / -None
들어가며
Triton의 Proton 프로파일러는 연산의 FLOPS와 메모리 사용량을 추적하기 위해 matmul launch 메타데이터를 수집합니다. 특히 'Ragged matmul'(입력 크기가 일정하지 않은 행렬 곱셈)의 경우, 기존에는 CPU 동기화 없이 메타데이터를 계산하기 위해 여러 개의 작은 Torch GPU 연산(aten.mul, aten.sum 등)을 순차적으로 실행했습니다. 이 방식은 결과값은 정확하지만, 다수의 커널을 실행하는 과정에서 불필요한 메타데이터 오버헤드를 발생시켰습니다. 본 PR은 이 과정을 단일 Triton 커널로 통합하여 프로파일링 성능을 크게 개선했습니다.
코드 분석
메타데이터 계산 커널 도입
핵심 변경 사항은 _matmul_flops_and_bytes_from_slices_kernel의 도입입니다. 기존에는 CPU에서 slice_sizes를 다루는 여러 Torch 연산을 수행했으나, 이제는 GPU 상에서 직접 계산을 수행합니다.
# 기존 방식 (다수의 Torch 커널 호출)
n_tokens = slice_sizes.sum()
flops = n_tokens.to(torch.float64) * (2.0 * M * N * z)
n_x_bytes = n_tokens * X.shape[-2] * X.element_size()
# ... (이후 여러 연산 추가)
# 개선된 방식 (Triton 커널 호출)
_matmul_flops_and_bytes_from_slices(args, M, N, K, X, Y, W, slice_sizes, nbits, batch_size)
루프를 통한 확장성 확보
리뷰어의 피드백을 반영하여, slice_sizes가 매우 큰 경우를 대비해 BLOCK_SIZE 단위로 나누어 처리하는 루프 구조를 도입했습니다. 이는 커널의 범용성을 높이고 메모리 제한을 우회하는 핵심적인 설계입니다.
# 리뷰어 제안을 반영한 루프 처리
for i in range(0, NUM_SLICES, BLOCK_SIZE):
offs = i + tl.arange(0, BLOCK_SIZE)
# ... 누적 연산 수행
왜 이게 좋은가
성능 개선 수치
벤치마크 결과, Ragged M 모드에서 약 4.8배, Ragged K 모드에서 약 3배의 GPU 속도 향상을 보였습니다. 이는 프로파일링 자체가 실제 연산 성능에 미치는 영향을 최소화하여, 더 정확한 성능 분석을 가능하게 합니다.
| Mode | Old GPU (us) | New GPU (us) | Speedup |
|---|---|---|---|
| ragged_m | 136.131 | 28.020 | 4.86x |
| ragged_k | 89.660 | 29.318 | 3.06x |
교훈
- 커널 통합의 힘: 다수의 작은 커널을 호출하는 것보다, 하나의 특화된 커널을 실행하는 것이 커널 런칭 오버헤드(Kernel Launch Overhead)를 줄이는 데 훨씬 효과적입니다.
- 동기화 최소화: CPU와 GPU 간의 잦은 동기화는 성능의 적입니다. 메타데이터 계산과 같은 보조적인 작업도 GPU 내에서 완결성을 갖도록 설계하는 것이 중요합니다.
- 확장성 고려:
BLOCK_SIZE를 활용한 루프 처리는 고정된 크기의 배열뿐만 아니라 가변적인 입력 데이터셋에 대해서도 커널이 안정적으로 동작하게 합니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [sglang] FlashInfer TRTLLM-Gen MoE 커널 최적화: NemotronH 모델 지원 및 성능 향상
- 현재글 : [triton] Triton의 Ragged Matmul 메타데이터 계산 최적화: CPU 동기화 없는 효율적인 프로파일링
- 다음글 [onnxruntime] ONNX Runtime의 RISC-V Vector(RVV) 최적화: SGEMM과 Softmax 성능을 3배로 끌어올리기
댓글