[onnxruntime] ONNX Runtime CPU GQA 최적화: INT8/INT4 양자화 KV 캐시와 SIMD 가속
PR 링크: microsoft/onnxruntime#28578 상태: Merged | 변경: +2162 / -19
들어가며
LLM 추론, 특히 디코딩 단계에서 CPU는 메모리 대역폭 제한(Memory-bandwidth-bound) 문제로 인해 성능 병목을 겪습니다. microsoft/onnxruntime의 이번 PR은 GroupQueryAttention(GQA) 연산에서 KV 캐시를 INT8 또는 INT4로 양자화하여 메모리 트래픽을 4~8배 줄이고, 이를 처리하기 위한 하드웨어 가속(AVX2, AVX512-VNNI, NEON) MLAS 커널을 도입했습니다.
코드 분석
1. MLAS 커널 최적화 (qkv_quant_kernel_avx2.cpp)
기존의 스칼라 기반 구현에서 SIMD로 전환하며, 특히 INT4 데이터 처리를 대폭 개선했습니다. 리뷰어의 피드백을 반영하여 스칼라 메모리 스토어/로드 과정을 제거하고 레지스터 내 연산으로 최적화했습니다.
Before (Scalar-based extraction):
// 스칼라 메모리 스토어 후 로드하는 비효율적 방식
alignas(16) int8_t nibbles[8];
// ... scalar extraction logic ...
__m128i v = _mm_loadl_epi64((__m128i*)nibbles);
After (In-register SIMD extraction):
// SSE 비트 연산을 사용하여 레지스터 내에서 직접 처리
__m128i low = _mm_and_si128(packed, _mm_set1_epi8(0x0F));
__m128i high = _mm_srli_epi16(_mm_and_si128(packed, _mm_set1_epi8(0xF0)), 4);
__m128i result = _mm_unpacklo_epi8(low, high);
2. GQA 연산자 검증 로직 (group_query_attention.cc)
양자화된 스케일 텐서의 크기 검증을 강화하여 런타임 크래시를 방지했습니다. 특히 CheckInputs() 이후에 head_size를 계산하여 안전하게 검증하도록 수정되었습니다.
// 수정 전: 랭크 체크 전 인덱싱으로 인한 크래시 위험
// 수정 후:
if (parameters.k_scale != nullptr) {
const auto& scale_dims = parameters.k_scale->Shape().GetDims();
size_t expected_size = (parameters.k_quant_type == QuantType::PER_CHANNEL) ?
(parameters.kv_num_heads * parameters.head_size) : 1;
ORT_ENFORCE(parameters.k_scale->Shape().Size() == expected_size);
}
왜 이게 좋은가
이번 최적화의 핵심은 메모리 대역폭 절감과 연산 처리량(Throughput) 증대입니다.
- 성능 수치: Intel Xeon Platinum 8480C 기준, INT8 Per-Tensor QKGemm 연산에서 AVX512-VNNI 커널 사용 시 스칼라 대비 15.6배의 성능 향상을 보였습니다.
- 유연성:
ORT_MLAS_QKGEMM_S8_APPROX_VNNI=1옵션을 통해 정확도와 속도 사이의 트레이드오프를 선택할 수 있게 설계되었습니다. - 교훈:
- SIMD 최적화 시 스칼라 메모리 접근을 최소화하고 레지스터 내 연산(Shuffle, Unpack)을 활용하는 것이 중요합니다.
- 하드웨어별(AVX2 vs AVX512 vs NEON)로 최적화된 커널을 런타임에 디스패치하는 구조는 이식성과 성능을 동시에 잡는 핵심 전략입니다.
리뷰어 피드백 반영
리뷰 과정에서 헤더 파일 누락(cstring, memory, vector) 문제와 AVX2 벤치마크의 타겟 아키텍처 가드(MLAS_TARGET_AMD64) 누락 등이 지적되었습니다. 작성자는 이를 즉각 반영하여 빌드 안정성을 확보했습니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://onnxruntime.ai/docs/api/c/index.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [flashinfer] FlashInfer의 DeepSeek V4 Sparse MLA 최적화 분석
- 현재글 : [onnxruntime] ONNX Runtime CPU GQA 최적화: INT8/INT4 양자화 KV 캐시와 SIMD 가속
- 다음글 [cpython] Python JIT 최적화: 트레이스 버퍼 오버헤드 관리 개선
댓글