[onnxruntime] ONNX Runtime의 CPU GQA 최적화: Flash Attention과 Flash Decoding 도입
PR 링크: microsoft/onnxruntime#28695 상태: Merged | 변경: +2762 / -107
들어가며
최근 LLM 추론 성능 최적화의 핵심은 메모리 대역폭과 연산 효율성입니다. 특히 CPU 환경에서 긴 시퀀스를 처리할 때, 기존의 Naive한 Attention 구현은 전체 Attention 행렬을 메모리에 할당해야 하므로 O(S×T)의 메모리 병목을 유발합니다. 이번 Microsoft의 ONNX Runtime PR은 CPU 환경에서 INT8/INT4 양자화된 KV 캐시를 위해 Flash Attention 스타일의 타일링 연산과 Flash Decoding을 도입하여, 메모리 사용량을 획기적으로 줄이고 추론 속도를 개선했습니다.
코드 분석
1. Flash Attention 타일링 (onnxruntime/core/mlas/lib/flashattn_qkv.cpp)
핵심 변경 사항은 전체 행렬을 메모리에 올리는 대신, L2 캐시 크기에 맞춘 블록 단위로 연산을 수행하는 것입니다. 온라인 소프트맥스(Online Softmax)를 사용하여 중간 결과값을 누적함으로써 메모리 사용량을 O(S×Bc)로 최적화했습니다.
// Before: Naive path allocates full [S, T] matrix
// After: Tiled computation with online softmax
for (size_t i = 0; i < kv_chunk_count; ++i) {
MlasQKGemm(..., kv_block_slice, ...);
// Online Softmax: Track running max and sum
UpdateSoftmaxStatistics(m, l, scores);
// Fused SV accumulation: Dequantize V on the fly
MlasSVGemm(..., Beta=1.0, ...);
}
2. Flash Decoding (onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h)
디코딩 단계(S=1)에서 배치 크기가 작을 경우, 유휴 상태인 CPU 스레드를 활용하여 KV 시퀀스를 병렬로 분할 처리합니다. 각 스레드가 부분적인 소프트맥스 통계를 계산하고, 마지막에 이를 리듀스(Reduce)하는 방식을 취합니다.
// Flash Decoding logic for S=1
if (batch_size * num_heads < thread_count && kv_chunk_count > 1) {
// Parallel KV scan across idle threads
ParallelFor(kv_chunk_count, [&](size_t chunk_idx) {
ComputePartialSoftmax(chunk_idx, partials[chunk_idx]);
});
// Reduce phase: Merge partials using log-sum-exp trick
ReducePartials(partials);
}
왜 이게 좋은가
- 메모리 효율성: 4096 시퀀스 길이에서 기존 대비 13~24배의 메모리 절감 효과를 보입니다. 이는 대규모 모델을 제한된 CPU 메모리에서 구동할 때 결정적인 차이를 만듭니다.
- 연산 속도: Prefill 단계에서 1.2
2.7배, Decode 단계에서 25배의 속도 향상을 달성했습니다. 특히 Flash Decoding은 스레드 활용도를 극대화하여 긴 시퀀스에서의 지연 시간을 크게 줄였습니다. - 일반적 교훈: 메모리 접근 패턴(Memory Access Pattern)을 캐시 친화적인 블록 단위로 재구성하고, 연산과 메모리 할당을 퓨전(Fusion)하는 것이 CPU 성능 최적화의 핵심임을 보여줍니다.
리뷰어 피드백 반영
- Ragged Seqlens 처리: 배치 내 시퀀스 길이가 다를 경우(Ragged), 기존의 통합 커널이 잘못된 메모리 영역을 참조할 수 있는 문제를 발견했습니다. 이를 위해
min_total_seqlen != max_total_seqlen일 경우 per-batch 호출로 폴백하도록 수정했습니다. - 수치 안정성: SIMD 연산 시 발생할 수 있는 부동소수점 오차를 고려하여 테스트 케이스의 허용 오차(Tolerance)를 1e-4로 완화하고, 마스킹된 행의 처리 로직을 보강하여 정확도를 확보했습니다.
참고 자료
- Flash Attention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- ONNX Runtime GQA Documentation
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://github.com/microsoft/onnxruntime/blob/main/docs/contrib_ops/cpu/gqa.md
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [flashinfer] FlashInfer MLA 커널 최적화: num_heads < 128 환경에서의 성능 극대화
- 현재글 : [onnxruntime] ONNX Runtime의 CPU GQA 최적화: Flash Attention과 Flash Decoding 도입
- 다음글 [vllm] AMD RDNA3 (gfx1100)를 위한 vLLM의 W4A16 GPTQ 커널 최적화 심층 분석
댓글