본문으로 건너뛰기

[onnxruntime] [ONNX Runtime] PagedAttention의 FA 경로 최적화 및 정확성 개선

PR 링크: microsoft/onnxruntime#28409 상태: Merged | 변경: +0 / -0

들어가며

ONNX Runtime의 PagedAttention 구현체에서 Flash Attention(FA) 경로가 사용하는 max_query_len 계산 방식에 문제가 있었습니다. 기존에는 token_count - batch_size + 1이라는 휴리스틱을 사용했는데, 이는 모든 배치가 최소 1개 이상의 새로운 토큰을 가진다는 가정하에 설계되었습니다. 하지만 실제 워크로드에서는 특정 배치가 0개의 새로운 토큰을 가질 수 있으며, 이로 인해 Rotary Embedding 커널의 토큰 누락, FA 그리드 설정 오류(CUDA error 9), 그리고 성능 저하가 발생했습니다. 본 PR은 이 휴리스틱을 호스트에서 계산된 정확한 최댓값으로 대체하여 정확성과 성능을 모두 잡았습니다.

코드 분석

1. onnxruntime/contrib_ops/cuda/bert/paged_attention.cc

기존에는 MEA(Memory Efficient Attention) 경로에서만 수행하던 max_query_len 계산을 FA 경로에서도 사용할 수 있도록 로직을 통합했습니다.

// Before
const int max_query_len = token_count - batch_size + 1;

// After (Host-side computation)
for (int i = 0; i < parameters.batch_size; ++i) {
  const int q_len_i = cum_q_pinned.get()[i + 1] - cum_q_pinned.get()[i];
  if (q_len_i > max_query_len) {
    max_query_len = q_len_i;
  }
}
data.max_query_len = max_query_len;

2. onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu

이제 FlashAttention 함수는 더 이상 휴리스틱에 의존하지 않고, paged_attention.cc에서 계산되어 전달된 data.max_query_len을 직접 사용합니다.

// Before
const int max_query_len = token_count - batch_size + 1;

// After
const int max_query_len = data.max_query_len;

왜 이게 좋은가

성능 향상 및 정확성 확보

  • 정확성: 기존 휴리스틱은 배치가 0개의 토큰을 가질 경우 음수 값을 생성하여 CUDA error 9를 유발하거나, Rotary Embedding 계산에서 토큰을 누락시켰습니다. 이제 정확한 값을 사용함으로써 이러한 버그가 완전히 해결되었습니다.
  • 성능: 대규모 Prefill 워크로드(B=64)에서 기존 휴리스틱은 불필요하게 큰 그리드를 생성하여 커널 실행 오버헤드가 컸습니다. 최적화 후, 대규모 배치에서 최대 92.3%의 성능 향상을 기록했습니다.

교훈

  1. 파일 경계 간의 불변성 관리: paged_attention.cc.cu 파일 사이의 데이터 계약(Contract)을 명확히 하는 것이 중요합니다. 리뷰어 tianleiwu가 제안한 것처럼, 파일 경계를 넘나드는 값에 대해서는 ORT_ENFORCE와 같은 방어적 코드를 추가하여 향후 리팩토링 시 발생할 수 있는 회귀를 방지해야 합니다.
  2. 휴리스틱의 위험성: 성능 최적화를 위한 휴리스틱은 엣지 케이스(예: 0 토큰 배치)에서 치명적인 오류를 유발할 수 있습니다. 가능한 경우 호스트 측에서 정확한 통계치를 계산하는 것이 오버헤드보다 이득이 클 때가 많습니다.

리뷰어 피드백 반영

리뷰 과정에서 assert 사용에 대한 논의가 있었습니다. Copilot은 python -O 옵션으로 인해 assert가 무시될 수 있음을 지적하며 예외 처리를 권장했으나, 기존 코드베이스의 일관성을 유지하기 위해 assert를 유지하되, 향후 리팩토링 시 명시적인 검증 로직을 도입하기로 합의했습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글