[triton] Triton에서 cuBLAS를 활용한 mxfp8 및 nvfp4 블록 스케일 행렬 곱셈 벤치마킹
PR 링크: triton-lang/triton#9044 상태: Merged | 변경: +565 / -71
들어가며
최신 AI 모델 학습 및 추론에서 연산 효율성을 극대화하기 위해 FP8, FP4와 같은 저정밀도 데이터 포맷과 블록 스케일링(Block-scaling) 기법이 활발히 도입되고 있습니다. Triton은 이러한 연산을 위한 커스텀 커널 작성을 지원하지만, 작성된 커널이 하드웨어의 잠재력을 얼마나 잘 활용하고 있는지 판단하기 위해서는 신뢰할 수 있는 베이스라인이 필수적입니다. 이번 PR은 NVIDIA의 cuBLAS 라이브러리를 활용하여 mxfp8 및 nvfp4 포맷에 대한 블록 스케일 행렬 곱셈 베이스라인을 구축하고, 기존 튜토리얼을 개선하여 성능 비교 환경을 마련했습니다.
코드 분석
1. python/tutorials/10-block-scaled-matmul.py 개선
기존 튜토리얼 코드에 cuBLAS를 연동하기 위한 cublas_block_scaled_matmul 함수가 추가되었습니다. 이 함수는 입력 데이터 포맷에 따라 적절한 cuBLAS API를 호출합니다.
def cublas_block_scaled_matmul(a, a_scale, b, b_scale, block_scale_type="mxfp8"):
# ... (생략)
if block_scale_type == "mxfp8":
output = torch.empty((M, N), dtype=torch.float16, device="cuda")
cublas.block_scaled_matmul_mxfp8(a, b, output, a_scale, b_scale)
elif block_scale_type == "nvfp4":
output = torch.empty((M, N), dtype=torch.float16, device="cuda")
cublas.block_scaled_matmul_nvfp4(a, b, output, a_scale, b_scale)
return output
또한, 벤치마크의 정확도를 높이기 위해 initialize_block_scaled 함수 내에서 스케일 텐서를 cuBLAS가 요구하는 레이아웃으로 변환하는 로직을 통합했습니다.
2. python/test/unit/runtime/test_blaslt.py 테스트 추가
새로운 기능이 올바르게 동작하는지 검증하기 위해 test_block_scaled_matmul_mxfp8 및 test_block_scaled_matmul_nvfp4 테스트 케이스가 추가되었습니다. 특히 Blackwell 아키텍처(compute capability 10.0)에서만 동작하도록 제약 조건을 설정하여 하드웨어 호환성을 보장했습니다.
def supports_block_scaling():
return is_cuda() and torch.cuda.get_device_capability()[0] == 10
왜 이게 좋은가
이번 최적화의 핵심은 **'벤치마킹의 신뢰성 확보'**입니다. 단순히 Triton 커널의 속도만 측정하는 것이 아니라, NVIDIA에서 최적화한 cuBLAS 라이브러리와 직접 비교함으로써 다음과 같은 이점을 얻습니다.
- 성능 격차 가시화: Triton 커널이 하드웨어 가속기를 얼마나 효율적으로 사용하는지 즉각적으로 파악할 수 있습니다.
- 최적화 방향 제시: cuBLAS 대비 성능이 낮다면, Triton 커널의 메모리 액세스 패턴이나 스케줄링 전략을 수정해야 할 명확한 근거가 됩니다.
- 표준화된 벤치마크: 튜토리얼에 warmup 반복을 추가하고 cuBLAS 베이스라인을 포함함으로써, 커뮤니티 기여자들이 보다 공정한 환경에서 성능을 측정할 수 있게 되었습니다.
일반적으로 고성능 컴퓨팅(HPC) 분야에서는 커스텀 커널을 작성할 때 항상 라이브러리 기반의 베이스라인을 함께 측정하는 것이 권장됩니다. 이는 하드웨어의 특수 기능을 얼마나 잘 활용하고 있는지 확인하는 가장 빠른 방법이기 때문입니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Loki] Partition Ring Shuffle Sharding에 LRU 캐시 도입
- 현재글 : [triton] Triton에서 cuBLAS를 활용한 mxfp8 및 nvfp4 블록 스케일 행렬 곱셈 벤치마킹
- 다음글 [Open WebUI] FileMetadataResponse의 meta 필드를 Optional로 변경하여 배치 추가 오류 수정
댓글