[onnxruntime] RISC-V 벡터(RVV) 최적화: ONNX Runtime LLM 추론 성능 극대화
PR 링크: microsoft/onnxruntime#28518 상태: Merged | 변경: +1758 / -5
들어가며
최근 RISC-V 아키텍처가 서버 및 엣지 디바이스에서 주목받으면서, 고성능 연산 처리를 위한 벡터 확장(RVV, RISC-V Vector Extension)의 중요성이 커지고 있습니다. 본 PR은 microsoft/onnxruntime의 MLAS(Microsoft Linear Algebra Subprograms) 라이브러리에 RVV 기반의 최적화된 LLM 연산자를 도입하여, SpacemiT K3 CPU 환경에서 LLM 추론 성능을 획기적으로 개선한 사례입니다.
코드 분석
1. MLAS 플랫폼 추상화 및 커널 등록
cmake/onnxruntime_mlas.cmake와 onnxruntime/core/mlas/lib/platform.cpp를 통해 RVV 커널을 동적으로 디스패치할 수 있는 구조를 마련했습니다. 특히 MlasLayerNormF32와 같은 새로운 API를 추가하여, 플랫폼별 최적화된 구현체를 호출할 수 있도록 설계되었습니다.
2. FP16 GEMM 최적화 (riscv64/halfgemm_kernel_rvv.cpp)
FP16 GEMM 연산에서 vfloat16m4_t를 사용하여 벡터 레지스터 그룹을 효율적으로 활용했습니다.
// Before: Scalar implementation (Baseline)
// After: RVV optimized kernel
vfloat16m4_t acc0, acc1, acc2, acc3;
// ... inner loop using vfmacc_vf ...
리뷰 과정에서 레지스터 압박(Register Pressure) 문제가 제기되었으나, GCC 15 컴파일러 환경에서 4개의 누산기(Accumulator)를 사용하더라도 벡터 레지스터 할당이 최적화됨을 확인했습니다.
3. LayerNorm 및 RMSNorm 최적화
기존의 스칼라 기반 구현을 RVV로 대체하여 연산 속도를 6배 이상 향상시켰습니다.
// After: RVV optimized LayerNorm
vfloat32m4_t vy = __riscv_vfmul_vf_f32m4(vx, inv_denom, vl);
vy = __riscv_vfmul_vv_f32m4(vy, vs, vl);
__riscv_vse32_v_f32m4(Output + i, vy, vl);
왜 이게 좋은가
성능 수치
- FP16 GEMM: 128x3072x768 크기에서 기존 스칼라 대비 약 187.8배의 속도 향상을 기록했습니다.
- RMSNorm: 4096 Hidden 차원에서 6.6배의 성능 향상을 보였습니다.
최적화 교훈
- 런타임 기능 감지:
__riscv_hwprobe를 사용하여 Zvfh(FP16 벡터) 지원 여부를 런타임에 확인함으로써, 단일 바이너리로 다양한 RISC-V 하드웨어를 지원하는 안정성을 확보했습니다. - 데이터 정렬 및 인터리빙: RoPE 연산 시
vlse32대신vlseg2e32를 사용하여 메모리 접근 패턴을 최적화함으로써 성능을 추가로 43% 개선했습니다. - API 설계: CPU EP(Execution Provider) 소스 코드에 직접적인 벡터 명령어를 삽입하는 대신, MLAS 라이브러리에 추상화된 API를 제공하여 유지보수성을 높였습니다.
리뷰어 피드백 반영
리뷰어 hariharans29는 런타임 크래시 방지를 위한 Zvfh 프로브 도입과, SimplifiedLayerNorm에서 Bias 처리 시 발생할 수 있는 정밀도 문제를 지적했습니다. 이에 대해 런타임 체크 로직을 추가하고, assert를 통한 안전장치를 마련하여 코드의 견고함을 높였습니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://github.com/riscv/riscv-v-spec
- https://onnxruntime.ai/docs/performance/model-optimizations.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [onnxruntime] ONNX Runtime의 RISC-V Vector(RVV) 최적화: SGEMM과 Softmax 성능을 3배로 끌어올리기
- [flashinfer] FlashInfer FP8 KV-Cache Prefill 성능 최적화: Repacking 기법을 통한 오버헤드 제거
- [sglang] SGLang의 NIXL 통신 최적화: Prep+Make API 도입을 통한 KV 캐시 전송 성능 향상
- [sglang] SGLang 스케줄러 최적화: input_ids H2D 지연 처리 및 FutureMap 통합
- [onnxruntime] ONNX Runtime의 CPU GQA 최적화: Flash Attention과 Flash Decoding 도입
PR Analysis 의 다른글
- 이전글 [vllm] vLLM 기술 딥다이브: CUTLASS를 활용한 NVFP4 Linear 커널의 Batch Invariance 최적화
- 현재글 : [onnxruntime] RISC-V 벡터(RVV) 최적화: ONNX Runtime LLM 추론 성능 극대화
- 다음글 [vllm] vLLM XPU MOE 성능 최적화: 호스트 오버헤드 감소를 위한 객체 지향적 접근
댓글