[onnxruntime] ONNX Runtime의 RISC-V Vector(RVV) 최적화: SGEMM과 Softmax 성능을 3배로 끌어올리기
PR 링크: microsoft/onnxruntime#28261 상태: Merged | 변경: +None / -None
들어가며
최근 오픈 소스 하드웨어 아키텍처인 RISC-V의 부상과 함께, 엣지 디바이스 및 고성능 컴퓨팅 분야에서 RISC-V의 활용도가 높아지고 있습니다. 하지만 그동안 ONNX Runtime의 CPU 실행 프로바이더(Execution Provider) 내 핵심 연산 라이브러리인 MLAS(Microsoft Linear Algebra Subprograms)는 RISC-V 환경에서 최적화된 벡터 명령어를 충분히 활용하지 못하고 스칼라(Scalar) 코드에 의존해 왔습니다.
이번 PR은 RISC-V의 벡터 확장인 RVV(RISC-V Vector)를 지원하여, 딥러닝의 핵심 연산인 SGEMM(Single-precision General Matrix Multiplication)과 Softmax의 성능을 비약적으로 향상시킨 사례입니다. 벤치마크 결과에 따르면 SGEMM은 약 3.6배, Softmax는 약 3.2배의 성능 향상을 달성했습니다.
코드 분석: RVV 지원을 위한 구조적 변화
1. 빌드 시스템 및 환경 감지 (CMake)
가장 먼저 수행된 작업은 컴파일러가 RVV 명령어를 지원하는지 확인하고, 적절한 컴파일 플래그를 설정하는 것입니다. 특히 리뷰어(Copilot)의 피드백을 반영하여 기존 CMAKE_REQUIRED_FLAGS를 덮어쓰지 않고 보존했다가 복구하는 안전한 방식을 채택했습니다.
Before (기존 방식의 위험성):
# 단순히 플래그를 설정하면 기존에 설정된 다른 중요 플래그가 유실될 수 있음
set(CMAKE_REQUIRED_FLAGS "-march=rv64gcv -mabi=lp64d")
After (개선된 방식):
# cmake/onnxruntime_mlas.cmake
if(onnxruntime_USE_RVV)
set(OLD_CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS}")
set(CMAKE_REQUIRED_FLAGS "${OLD_CMAKE_REQUIRED_FLAGS} -march=rv64gcv -mabi=lp64d")
check_cxx_source_compiles("
#include <stddef.h>
#include <riscv_vector.h>
int main() {
size_t vl = __riscv_vsetvl_e32m1(4);
return static_cast<int>(vl == 0);
}"
HAS_RISCV64_RVV
)
set(CMAKE_REQUIRED_FLAGS "${OLD_CMAKE_REQUIRED_FLAGS}")
# ...
endif()
2. MLAS 추상화 계층 연결
MLAS는 다양한 아키텍처(x86, ARM, SVE 등)를 지원하기 위해 함수 포인터 기반의 플랫폼 추상화 구조를 가집니다. 이번 변경을 통해 MLAS_TARGET_RISCV64 타겟에서도 RVV 전용 커널을 호출할 수 있도록 구조를 확장했습니다.
onnxruntime/core/mlas/lib/mlasi.h:
#elif defined(MLAS_TARGET_RISCV64)
#if defined(MLAS_USE_RVV)
MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelRvv;
void MlasSgemmCopyPackBRvv(
float* D,
const float* B,
size_t ldb,
size_t CountX,
size_t CountY);
#endif
리뷰어의 제안에 따라 MLAS_USE_RVV 매크로를 사용하여, RVV가 활성화되지 않은 환경에서 불필요한 심볼이 참조되는 것을 방지했습니다.
3. Softmax 연산의 벡터화
Softmax 연산은 ReduceMaximum, ComputeSumExp, ComputeSoftmaxOutput의 세 단계로 나뉩니다. 기존에는 이를 스칼라 루프 내에서 처리했으나, 이제는 RVV의 vsetvl(Vector Length 설정)과 벡터 연산 명령어를 통해 병렬 처리됩니다.
onnxruntime/core/mlas/lib/compute.cpp:
// Softmax의 각 단계에서 RVV 타겟이 정의되어 있으면 전용 커널을 호출
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_USE_SVE) || defined(MLAS_TARGET_RISCV64)
Maximum = GetMlasPlatform().ReduceMaximumF32Kernel(Input, D);
#else
Maximum = MlasReduceMaximumF32Kernel(Input, D);
#endif
왜 이게 좋은 최적화인가?
1. 하드웨어 가변 벡터 길이(VLA) 대응
RISC-V Vector의 특징은 하드웨어마다 벡터 레지스터의 길이(VLEN)가 다를 수 있다는 점입니다. 이번에 추가된 sgemm_kernel_rvv.cpp와 softmax_kernel_rvv.cpp는 __riscv_vsetvl 명령어를 사용하여 런타임에 최적의 벡터 길이를 결정하므로, 다양한 RISC-V 구현체에서 재컴파일 없이 최적의 성능을 낼 수 있습니다.
2. 메모리 대역폭 최적화 (SGEMM Pack B)
SGEMM 성능 향상의 핵심 중 하나는 sgemm_pack_b_rvv.cpp의 도입입니다. 행렬 곱셈 시 메모리 액세스 패턴을 최적화하기 위해 B 행렬을 재배치(Packing)하는 과정에도 RVV를 적용하여, 전체 End-to-end 속도를 크게 개선했습니다.
3. 성능 수치로 증명된 결과
| 연산 | 케이스 | 속도 향상 (Speedup) |
|---|---|---|
| SGEMM | 128x3072x768 | 3.62x (Compute) |
| Softmax | 4096x128 | 3.20x |
단순히 코드를 추가한 것에 그치지 않고, 스칼라 대비 3배 이상의 성능 향상을 이끌어냄으로써 RISC-V 기반 AI 추론의 실용성을 입증했습니다.
일반적인 교훈
- Transitive Include의 위험성: 리뷰 과정에서
softmax_kernel_rvv.cpp에<limits>헤더가 누락된 점이 지적되었습니다. 특정 환경에서 우연히 컴파일되더라도, 다른 툴체인에서는 깨질 수 있으므로 사용하는 API의 헤더는 명시적으로 포함해야 합니다. - 환경 변수 파싱의 엄격함:
ORT_MLAS_RISCV_FORCE_SCALAR와 같은 디버깅용 플래그를 처리할 때, 단순히 '0'이 아니면 참으로 간주하는 방식은 위험할 수 있습니다.true/false,1/0등 명확한 불리언 값을 처리하는 로직이 견고한 소프트웨어를 만듭니다. - 추상화 계층의 중요성: MLAS와 같이 잘 설계된 플랫폼 추상화 계층이 있었기에, 새로운 아키텍처(RVV) 지원을 기존 로직의 큰 수정 없이 깔끔하게 주입할 수 있었습니다.
이 PR은 향후 RISC-V 생태계에서 ONNX Runtime이 표준 추론 엔진으로 자리 잡는 데 중요한 초석이 될 것입니다.
참고 자료
- https://github.com/riscv_non_isa/rvv-intrinsic-doc
- https://cmake.org/cmake/help/latest/module/CheckCXXSourceCompiles.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] Triton의 Ragged Matmul 메타데이터 계산 최적화: CPU 동기화 없는 효율적인 프로파일링
- 현재글 : [onnxruntime] ONNX Runtime의 RISC-V Vector(RVV) 최적화: SGEMM과 Softmax 성능을 3배로 끌어올리기
- 다음글 [vllm] vLLM chunk_kda 커널의 숨겨진 상태(h) 레이아웃 불일치 버그 수정 및 정확도 개선
댓글