[onnxruntime] Apple M4 Max를 위한 FlashAttention 최적화: 20배 성능 향상 분석
PR 링크: microsoft/onnxruntime#27780 상태: Merged | 변경: +0 / -0
들어가며
최근 microsoft/onnxruntime 레포지토리에 흥미로운 PR이 머지되었습니다. Apple M4 Max 환경에서 MultiHeadAttention 연산의 속도를 기존 대비 약 20배 향상시킨 최적화입니다. 본 글에서는 이 PR이 어떤 기술적 의사결정을 통해 성능 병목을 해결했는지, 그리고 왜 특정 하드웨어에서 성능 회귀(regression)가 발생했는지 분석합니다.
코드 분석
1. flash_attention.h: 하드웨어별 동적 파라미터 설정
기존에는 max_k_step이 16으로 고정되어 있었으나, Apple GPU의 공유 메모리(Shared Memory) 버짓을 고려하여 동적으로 계산하도록 변경되었습니다.
// Before
max_k_step_ = 16;
// After
if (is_apple) {
const int element_size = is_fp16 ? 2 : 4;
int max_k_from_shm = 16384 / (2 * element_size * qkv_head_size);
max_k_step_ = (max_k_from_shm >= 32) ? 32 : 16;
} else {
max_k_step_ = 16;
}
2. flash_attention.wgsl.template: Apple 전용 커널 경로 추가
Apple GPU의 아키텍처 특성을 활용하기 위해 별도의 WGSL 템플릿 경로를 분기했습니다. 특히 is_apple 플래그를 통해 Apple 하드웨어에서만 최적화된 loadk, loadv 함수를 사용하도록 설계했습니다.
#if is_apple
fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32) {
// Apple GPU에 최적화된 루프 기반 로드 로직
for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) {
k_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq);
}
}
#else
// 기존 범용 로직
fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32) { ... }
#endif
왜 이게 좋은가
성능 수치
- LightOnOCR-2-1B-ONNX (Vision Encoder): 58.3s → 2.89s (약 20배 향상)
- all-MiniLM-L6-v2: 4.36ms → 2.02ms (약 2.16배 향상)
핵심 교훈
- 레지스터 압박(Register Pressure) 관리:
max_k_step을 무조건 높이는 것은 레지스터 스필링(spilling)을 유발하여 오히려 성능을 저하시킵니다. 리뷰 과정에서 64 대신 32를 선택함으로써 성능 유지와 안정성을 모두 확보했습니다. - 하드웨어별 분기 전략: 범용적인 최적화가 모든 기기에서 통하지 않음을 확인했습니다. 특히 Qualcomm 기기에서 발생한 성능 회귀를 해결하기 위해 하드웨어별로 경로를 분기하는 전략이 필수적이었습니다.
- 공유 메모리 활용: Apple GPU의 공유 메모리 버짓을 계산하여
max_k_step을 동적으로 결정한 점은 하드웨어 가속기 최적화의 정석을 보여줍니다.
리뷰어 피드백 분석
리뷰 과정에서 qjia7은 Qualcomm 기기에서의 성능 회귀를 지적했고, xenova는 이를 수용하여 Qualcomm 경로를 복구했습니다. 또한, subgroupShuffle 사용 시 Apple 하드웨어에서 발생하는 예기치 못한 이슈를 발견하여, 무조건적인 최적화보다는 하드웨어별 특성을 고려한 튜닝이 중요함을 다시 한번 입증했습니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [onnxruntime] [ONNX Runtime] PagedAttention의 FA 경로 최적화 및 정확성 개선
- [sglang] SGLang: Triton 버전 업그레이드에 따른 MoE 성능 회귀 해결 및 설정 자동화
- [vllm] vLLM의 분산 추론 성능 극대화: 양방향 KV 캐시 전송을 통한 Prefill 최적화
- [onnxruntime] ONNX Runtime의 RISC-V Vector(RVV) 최적화: SGEMM과 Softmax 성능을 3배로 끌어올리기
- [sglang] SGLang에서 GLM-5 모델 성능 최적화: Aiter 백엔드 활용 및 텐서 패딩 전략
PR Analysis 의 다른글
- 이전글 [vllm] Blackwell을 위한 새로운 MLA 백엔드: TOKENSPEED_MLA 분석 (DeepSeek R1 최적화)
- 현재글 : [onnxruntime] Apple M4 Max를 위한 FlashAttention 최적화: 20배 성능 향상 분석
- 다음글 [vllm] vLLM의 NIXL KV 전송을 활용한 GDN(Gated Delta Net) 모델 지원 최적화
댓글