[onnxruntime] [ONNX Runtime] SGEMM의 함정에서 벗어나기: GQA 전용 GEMV 커널을 통한 디코딩 최적화
PR 링크: microsoft/onnxruntime#29216 상태: Merged | 변경: +707 / -39
들어가며
LLM(Large Language Model) 추론 성능의 핵심은 Prefill과 Decode라는 두 가지 서로 다른 단계를 얼마나 효율적으로 처리하느냐에 달려 있습니다. 최근 Microsoft의 onnxruntime 레포지토리에 올라온 PR #29216은 FP32 환경에서 Group Query Attention(GQA)의 디코딩 성능을 획기적으로 개선한 사례입니다.
이 PR이 해결하고자 하는 핵심 문제는 "왜 Flash Attention 기법을 적용했는데 디코딩(Single-token decode) 속도가 더 느려졌는가?"입니다. 이전 PR(#28962)에서 Prefill 단계에는 Flash Attention을 성공적으로 적용했지만, 디코딩 단계에서는 오히려 Naive 방식보다 0.4~0.6배 수준으로 성능이 떨어지는(Regression) 현상이 발생했습니다. 원인은 바로 SGEMM(General Matrix-Matrix Multiplication)의 오버헤드였습니다.
문제 분석: SGEMM vs GEMV
디코딩 단계에서는 쿼리(Query)의 행 개수($M$)가 항상 1입니다. 일반적인 SGEMM 라이브러리는 행렬 연산의 효율을 높이기 위해 데이터를 특정 레이아웃으로 재배치하는 B-packing(Setup) 과정을 거칩니다. 하지만 $M=1$인 경우, 데이터를 재배치하는 데 드는 비용이 실제 연산 비용보다 커지는 배보다 배꼽이 더 큰 상황이 발생합니다.
이 PR은 이러한 SGEMM의 오버헤드를 피하기 위해, $M=1$ 전용 GEMV(General Matrix-Vector Multiplication) 커널을 도입했습니다.
코드 분석: 파일별 핵심 변경 사항
1. 실행 전략의 명시적 구분 (docs/contrib_ops/cpu/gqa.md)
가장 먼저 문서상에서 디코딩 전략이 어떻게 바뀌었는지 확인할 수 있습니다. 기존에는 디코딩 시 Naive 경로로 폴백(Fallback)했으나, 이제는 전용 GEMV 커널을 사용합니다.
Before:
- ... the non-quantized FP32 path is limited to prefill (sequence_length > 1) and uses the naive path for decode.
After:
- ... the non-quantized FP32 path reuses the same tiling, online-softmax, masking, and flash-decoding structure.
- Decode is therefore handled by a dedicated GEMV kernel (MlasGQADecodeGQAThreaded), dispatched whenever sequence_length == 1...
2. 디코딩 모드 결정 로직 (gqa_attention_base.h)
디코딩 상황에서 스레드 풀의 효율을 극대화하기 위해 두 가지 모드를 제안합니다. batch * heads가 스레드 수보다 적으면 KV 캐시를 여러 청크로 나누어 병렬 처리하는 Flash Decoding을 수행하고, 그렇지 않으면 단일 패스(Single-pass) 디코딩을 수행합니다.
After (Logic in gqa_attention_base.h):
// Flash decoding: sequence_length == 1일 때, batch * heads < thread_count인 경우 KV를 스레드 간 분할
const bool use_flash_decoding = (sequence_length == 1 &&
common_past_seqlen >= 0 &&
batch_size * num_heads_ < thread_count &&
kv_chunk_count > 1);
if (use_flash_decoding) {
// Flash decoding용 partials 버퍼 할당 및 GEMV 기반 커널 호출 준비
// ...
} else if (sequence_length == 1) {
// Single-pass GEMV 디코딩 커널 호출
// ...
}
3. SIMD 최적화 GEMV 헬퍼 (flashattn_gqa.cpp)
실제 연산을 담당하는 MlasGQADecodeQK와 MlasGQADecodeSV는 SGEMM을 호출하지 않고 직접 벡터-행렬 곱을 수행합니다. 특히 QK 연산에서는 8개의 accumulator lane을 사용하여 컴파일러가 -ffast-math 옵션 없이도 효율적인 SIMD FMA(Fused Multiply-Add) 인스트럭션을 생성할 수 있도록 유도했습니다.
리뷰어 hariharans29의 질문에 대해 저자 tianleiwu는 다음과 같이 설명합니다:
"GEMV 헬퍼는 $M=1$인 퇴화된(degenerate) 곱셈을 위해 SGEMM의 B-packing/setup 오버헤드 없이 K/V 행을 직접 스트리밍합니다."
왜 이게 좋은 최적화인가?
1. 메모리 대역폭 한계(Memory-Bandwidth Bound)의 이해
디코딩은 연산량보다 메모리에서 데이터를 읽어오는 속도가 병목인 작업입니다. SGEMM은 데이터를 캐시에 최적화된 형태로 '패킹'하려고 시도하지만, 디코딩에서는 각 데이터를 단 한 번만 읽고 버리기 때문에 패킹 자체가 낭비입니다. GEMV 커널은 데이터를 읽는 즉시 연산에 투입하여 불필요한 메모리 쓰기/읽기 사이클을 제거했습니다.
2. 성능 수치 (AMD EPYC 7763 기준)
- Short/Medium Context: Naive 방식과 대등하거나 소폭 우세 (오버헤드 지배적).
- Long Context (T=4097): Naive 대비 1.35~1.5x 성능 향상.
- 이전 Flash Attention 적용 시 발생했던 0.4~0.6x의 성능 저하를 완전히 해결하고 오히려 개선했습니다.
3. 견고한 엔지니어링: Parity Test
리뷰 과정에서 Copilot과 리뷰어들은 Ragged Sequence(배치마다 길이가 다른 경우)에서의 결정론적(Deterministic) 동작 문제를 지적했습니다. 저자는 이를 해결하기 위해 test_gqa_decode_flash_vs_naive_parity 테스트 케이스를 추가하여, Naive 경로와 Flash/GEMV 경로의 결과값이 1e-8 이내로 일치함을 검증했습니다. 이는 최적화 과정에서 발생할 수 있는 수치적 오류를 사전에 차단한 훌륭한 사례입니다.
결론
이번 PR은 단순히 "최신 논문의 알고리즘(Flash Attention)을 적용하는 것"보다 "타겟 워크로드(M=1)의 특성에 맞는 커널(GEMV)을 선택하는 것"이 얼마나 중요한지 보여줍니다. 고성능 라이브러리를 설계할 때는 범용적인 SGEMM에 의존하기보다, 특정 상황에서의 오버헤드를 분석하고 전용 경로(Fast-path)를 구축하는 집요함이 필요합니다.
소프트웨어 엔지니어로서 우리는 항상 도구(SGEMM)의 추상화 뒤에 숨겨진 비용을 인지해야 하며, 특히 성능이 중요한 추론 엔진에서는 데이터의 흐름과 하드웨어의 특성을 일치시키는 최적화가 최고의 가치를 만들어냅니다.
참고 자료
- https://github.com/microsoft/onnxruntime/pull/29216
- https://arxiv.org/abs/2307.08691
- https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [onnxruntime] ONNX Runtime: MoE Router GEMV 최적화 및 Bias Fusion 구현
- [vllm] vLLM, DFlash 도입으로 추론 속도 1.2배 향상: MRV2와 CUDAGraph의 시너지
- [onnxruntime] ONNX Runtime CUDA Graph: 진정한 비동기 추론을 위한 동기화 지점 제거
- [sglang] 성능 최적화의 함정: DeepSeek-V3.2 정확도 붕괴를 막기 위한 SGLang의 긴급 롤백 분석
- [sglang] SGLang 성능 최적화: PDL 도입과 안전한 CUDA 동기화로 DSV3.2/GLM-5 가속하기
PR Analysis 의 다른글
- 이전글 [sglang] SGLang의 Qwen3.5 성능 극대화: Fused QK GemmaRMSNorm + RoPE 커널 최적화 분석
- 현재글 : [onnxruntime] [ONNX Runtime] SGEMM의 함정에서 벗어나기: GQA 전용 GEMV 커널을 통한 디코딩 최적화
- 다음글 [vllm] vLLM ROCm 환경에서 Shared-Expert Fusion을 통한 MoE 추론 성능 최적화
댓글