[onnxruntime] ONNX Runtime CPU ScatterElements 커널의 멀티스레딩 최적화 분석
PR 링크: microsoft/onnxruntime#28588 상태: Merged | 변경: +138 / -93
들어가며
ONNX Runtime(ORT)의 ScatterElements 연산은 텐서의 특정 축을 따라 값을 분산시키는 중요한 연산입니다. 하지만 기존 구현은 완전히 순차적(sequential)으로 동작하여, 멀티코어 환경에서도 단일 스레드만 사용하는 비효율적인 구조였습니다. 특히 ARM 기반 24코어 시스템에서 약 761ms가 소요되던 작업이 최적화 후 6ms로 단축되는 등, 성능 병목이 심각했습니다. 본 글에서는 ThreadPool을 도입하여 이 문제를 어떻게 해결했는지 살펴봅니다.
코드 분석
1. GetIndices 병렬화
인덱스 유효성 검사 및 정규화 과정은 기존에 단일 루프로 수행되었습니다. 이를 concurrency::ThreadPool::TryParallelFor를 사용하여 병렬화했습니다.
Before:
for (int64_t i = 0; i < num_indices; ++i) {
const int64_t idx = static_cast<int64_t>(indices_data_raw[i]);
// ... 유효성 검사 및 변환
}
After:
concurrency::ThreadPool::TryParallelFor(
tp, narrow<std::ptrdiff_t>(num_indices), 1.0,
[&](std::ptrdiff_t first, std::ptrdiff_t last) {
for (std::ptrdiff_t i = first; i < last; ++i) {
// ... 병렬 처리
}
});
여기서 std::atomic을 사용하여 여러 스레드에서 오류를 감지할 때 발생할 수 있는 경쟁 상태(race condition)를 안전하게 처리했습니다.
2. ScatterData 병렬화
핵심은 axis를 기준으로 outer_size * inner_size만큼의 독립적인 작업 단위로 분해하는 것입니다. 각 작업 단위는 서로 다른 메모리 영역에 쓰기 작업을 수행하므로 락(lock) 없이 병렬 처리가 가능합니다.
After (핵심 로직):
concurrency::ThreadPool::TryParallelFor(
tp, narrow<std::ptrdiff_t>(total_work_units), static_cast<double>(axis_size),
[&](std::ptrdiff_t first, std::ptrdiff_t last) {
for (std::ptrdiff_t work_idx = first; work_idx < last; ++work_idx) {
// ... 오프셋 계산 및 scatter 수행
}
});
왜 이게 좋은가
이 최적화의 핵심 교훈은 '데이터 의존성 분석을 통한 락 프리(lock-free) 병렬화'입니다. ScatterElements 연산은 특정 축을 기준으로 작업 단위가 서로 겹치지 않음을 보장할 수 있습니다. 이를 통해 복잡한 동기화 기법 없이도 선형적인 성능 향상을 얻었습니다.
- 성능 수치: 24코어 ARM 시스템 기준 761ms → 6ms (약 129배 향상).
- 교훈:
TryParallelFor와 같은 추상화된 스레드 풀을 활용하면 플랫폼 종속적인 스레드 관리를 피할 수 있습니다.static_cast대신onnxruntime::narrow를 사용하여 32비트/64비트 환경 간의 타입 변환 안전성을 확보하는 것이 중요합니다.- 에러 처리를 위한
std::atomic사용 시, 성능 저하를 최소화하기 위해 '최초 발견' 방식의 비결정적 보고를 허용하는 트레이드오프 전략이 유효합니다.
리뷰어 피드백 반영
- 타입 안전성:
static_cast<std::ptrdiff_t>사용 시 발생할 수 있는 오버플로우 문제를narrow<std::ptrdiff_t>로 수정하여 32비트 빌드 안정성을 확보했습니다. - 에러 보고: 병렬 처리 중 에러 발생 시 결정론적(deterministic) 순서를 보장하기 위한 추가 오버헤드 대신, 비결정적이지만 효율적인 방식을 채택하고 이를 주석으로 명시했습니다.
참고 자료
- https://onnxruntime.ai/docs/api/c/struct_onnx_runtime_1_1concurrency_1_1_thread_pool.html
- https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterElements
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [transformers] Hugging Face Transformers: SequenceFeatureExtractor.pad() 최적화로 불필요한 NumPy 배열 재변환 제거
- [uv] uv의 로컬 휠(Wheel) 압축 해제 성능 회귀 문제 해결: astral_async_zip 버전 업데이트
- [cpython] tarfile 스트리밍 모드(r|*) 성능 개선: 파이썬 압축 파일 처리의 숨겨진 병목 제거
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [feast] Feast Feature Server의 직렬화 성능 4배 향상: MessageToDict 최적화
PR Analysis 의 다른글
- 이전글 [cpython] Python JIT 최적화: 트레이스 버퍼 오버헤드 관리 개선
- 현재글 : [onnxruntime] ONNX Runtime CPU ScatterElements 커널의 멀티스레딩 최적화 분석
- 다음글 [ultralytics] Ultralytics 코드베이스 경량화: SciPy 의존성 감소 및 NumPy 기반 최적화
댓글