본문으로 건너뛰기

[onnxruntime] ONNX Runtime의 RISC-V Vector(RVV) 최적화: SGEMM과 Softmax 성능을 3배로 끌어올리기

PR 링크: microsoft/onnxruntime#28261 상태: Merged | 변경: +None / -None

들어가며

최근 오픈 소스 하드웨어 아키텍처인 RISC-V의 부상과 함께, 엣지 디바이스 및 고성능 컴퓨팅 분야에서 RISC-V의 활용도가 높아지고 있습니다. 하지만 그동안 ONNX Runtime의 CPU 실행 프로바이더(Execution Provider) 내 핵심 연산 라이브러리인 MLAS(Microsoft Linear Algebra Subprograms)는 RISC-V 환경에서 최적화된 벡터 명령어를 충분히 활용하지 못하고 스칼라(Scalar) 코드에 의존해 왔습니다.

이번 PR은 RISC-V의 벡터 확장인 RVV(RISC-V Vector)를 지원하여, 딥러닝의 핵심 연산인 SGEMM(Single-precision General Matrix Multiplication)Softmax의 성능을 비약적으로 향상시킨 사례입니다. 벤치마크 결과에 따르면 SGEMM은 약 3.6배, Softmax는 약 3.2배의 성능 향상을 달성했습니다.


코드 분석: RVV 지원을 위한 구조적 변화

1. 빌드 시스템 및 환경 감지 (CMake)

가장 먼저 수행된 작업은 컴파일러가 RVV 명령어를 지원하는지 확인하고, 적절한 컴파일 플래그를 설정하는 것입니다. 특히 리뷰어(Copilot)의 피드백을 반영하여 기존 CMAKE_REQUIRED_FLAGS를 덮어쓰지 않고 보존했다가 복구하는 안전한 방식을 채택했습니다.

Before (기존 방식의 위험성):

# 단순히 플래그를 설정하면 기존에 설정된 다른 중요 플래그가 유실될 수 있음
set(CMAKE_REQUIRED_FLAGS "-march=rv64gcv -mabi=lp64d")

After (개선된 방식):

# cmake/onnxruntime_mlas.cmake
if(onnxruntime_USE_RVV)
  set(OLD_CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS}")
  set(CMAKE_REQUIRED_FLAGS "${OLD_CMAKE_REQUIRED_FLAGS} -march=rv64gcv -mabi=lp64d")
  check_cxx_source_compiles("
    #include <stddef.h>
    #include <riscv_vector.h>
    int main() {
      size_t vl = __riscv_vsetvl_e32m1(4);
      return static_cast<int>(vl == 0);
    }"
    HAS_RISCV64_RVV
  )
  set(CMAKE_REQUIRED_FLAGS "${OLD_CMAKE_REQUIRED_FLAGS}")
  # ...
endif()

2. MLAS 추상화 계층 연결

MLAS는 다양한 아키텍처(x86, ARM, SVE 등)를 지원하기 위해 함수 포인터 기반의 플랫폼 추상화 구조를 가집니다. 이번 변경을 통해 MLAS_TARGET_RISCV64 타겟에서도 RVV 전용 커널을 호출할 수 있도록 구조를 확장했습니다.

onnxruntime/core/mlas/lib/mlasi.h:

#elif defined(MLAS_TARGET_RISCV64)
#if defined(MLAS_USE_RVV)
    MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelRvv;
    void MlasSgemmCopyPackBRvv(
        float* D,
        const float* B,
        size_t ldb,
        size_t CountX,
        size_t CountY);
#endif

리뷰어의 제안에 따라 MLAS_USE_RVV 매크로를 사용하여, RVV가 활성화되지 않은 환경에서 불필요한 심볼이 참조되는 것을 방지했습니다.

3. Softmax 연산의 벡터화

Softmax 연산은 ReduceMaximum, ComputeSumExp, ComputeSoftmaxOutput의 세 단계로 나뉩니다. 기존에는 이를 스칼라 루프 내에서 처리했으나, 이제는 RVV의 vsetvl(Vector Length 설정)과 벡터 연산 명령어를 통해 병렬 처리됩니다.

onnxruntime/core/mlas/lib/compute.cpp:

// Softmax의 각 단계에서 RVV 타겟이 정의되어 있으면 전용 커널을 호출
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_USE_SVE) || defined(MLAS_TARGET_RISCV64)
    Maximum = GetMlasPlatform().ReduceMaximumF32Kernel(Input, D);
#else 
    Maximum = MlasReduceMaximumF32Kernel(Input, D);
#endif

왜 이게 좋은 최적화인가?

1. 하드웨어 가변 벡터 길이(VLA) 대응

RISC-V Vector의 특징은 하드웨어마다 벡터 레지스터의 길이(VLEN)가 다를 수 있다는 점입니다. 이번에 추가된 sgemm_kernel_rvv.cppsoftmax_kernel_rvv.cpp__riscv_vsetvl 명령어를 사용하여 런타임에 최적의 벡터 길이를 결정하므로, 다양한 RISC-V 구현체에서 재컴파일 없이 최적의 성능을 낼 수 있습니다.

2. 메모리 대역폭 최적화 (SGEMM Pack B)

SGEMM 성능 향상의 핵심 중 하나는 sgemm_pack_b_rvv.cpp의 도입입니다. 행렬 곱셈 시 메모리 액세스 패턴을 최적화하기 위해 B 행렬을 재배치(Packing)하는 과정에도 RVV를 적용하여, 전체 End-to-end 속도를 크게 개선했습니다.

3. 성능 수치로 증명된 결과

연산 케이스 속도 향상 (Speedup)
SGEMM 128x3072x768 3.62x (Compute)
Softmax 4096x128 3.20x

단순히 코드를 추가한 것에 그치지 않고, 스칼라 대비 3배 이상의 성능 향상을 이끌어냄으로써 RISC-V 기반 AI 추론의 실용성을 입증했습니다.


일반적인 교훈

  1. Transitive Include의 위험성: 리뷰 과정에서 softmax_kernel_rvv.cpp<limits> 헤더가 누락된 점이 지적되었습니다. 특정 환경에서 우연히 컴파일되더라도, 다른 툴체인에서는 깨질 수 있으므로 사용하는 API의 헤더는 명시적으로 포함해야 합니다.
  2. 환경 변수 파싱의 엄격함: ORT_MLAS_RISCV_FORCE_SCALAR와 같은 디버깅용 플래그를 처리할 때, 단순히 '0'이 아니면 참으로 간주하는 방식은 위험할 수 있습니다. true/false, 1/0 등 명확한 불리언 값을 처리하는 로직이 견고한 소프트웨어를 만듭니다.
  3. 추상화 계층의 중요성: MLAS와 같이 잘 설계된 플랫폼 추상화 계층이 있었기에, 새로운 아키텍처(RVV) 지원을 기존 로직의 큰 수정 없이 깔끔하게 주입할 수 있었습니다.

이 PR은 향후 RISC-V 생태계에서 ONNX Runtime이 표준 추론 엔진으로 자리 잡는 데 중요한 초석이 될 것입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글