[onnxruntime] WebGPU FlashAttention 최적화: 커널 퓨전과 가변 시퀀스 길이 지원으로 성능 극대화
PR 링크: microsoft/onnxruntime#28389 상태: Merged | 변경: +426 / -426
들어가며
최근 AI 모델, 특히 대규모 언어 모델(LLM)의 발전은 GPU 컴퓨팅 성능에 대한 요구를 끊임없이 증대시키고 있습니다. 트랜스포머 아키텍처의 핵심인 어텐션 메커니즘은 이러한 성능 요구의 중심에 있으며, FlashAttention과 같은 최적화된 구현은 LLM의 추론 속도를 크게 향상시키는 데 기여하고 있습니다. Microsoft의 ONNX Runtime은 WebGPU 백엔드를 통해 다양한 환경에서 AI 모델을 효율적으로 실행할 수 있도록 지원하는데, 이번 PR은 WebGPU 환경에서 FlashAttention의 디코딩 경로를 최적화하여 성능을 극대화하는 것을 목표로 합니다.
이 PR은 크게 세 가지 주요 개선 사항을 포함합니다:
- 커널 퓨전 (Kernel Fusion): 기존의 여러 단계로 나뉘어 있던 어텐션 계산 커널들을 하나의 통합된 커널로 합쳐 디스패치 횟수를 줄이고 중간 텐서 생성을 제거합니다.
- 시퀀스 길이 확장: 기존에는 고정된 시퀀스 길이(seq_len=1)만 지원하던 디코딩 경로를 모든 시퀀스 길이에 대해 동작하도록 확장하고, KV 캐시 로딩을 효율화합니다.
- 경로 라우팅 최적화: 모델의 시퀀스 길이에 따라 최적의 어텐션 계산 경로(Prefill vs. Split-Reduce)를 선택하는 휴리스틱을 개선하여 성능을 향상시킵니다.
이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 이러한 최적화가 왜 효과적인지, 그리고 어떤 기술적 교훈을 얻을 수 있는지 살펴보겠습니다.
코드 분석
1. 커널 퓨전: FlashAttentionDecodeQKTProgram + FlashAttentionDecodeSplitVxProgram → FlashAttentionDecodeQKVProgram
가장 핵심적인 변경 중 하나는 기존의 FlashAttentionDecodeQKTProgram과 FlashAttentionDecodeSplitVxProgram 두 개의 커널을 FlashAttentionDecodeQKVProgram이라는 단일 커널로 통합한 것입니다. 이 통합은 어텐션 계산의 중간 결과물인 qk 텐서(B×H×seq×present_seq 크기)를 생성하고 저장하는 과정을 제거하여 메모리 사용량을 줄이고, GPU 커널 디스패치 횟수를 3회에서 2회로 감소시킵니다.
Before (Conceptual):
// 기존 로직 (개념적)
// 1. QKT 계산 및 Softmax (FlashAttentionDecodeQKTProgram)
// 2. V와 곱셈 및 결과 분할 (FlashAttentionDecodeSplitVxProgram)
// 3. 최종 결과 취합 (VxReduce)
After (Fused Kernel):
// 변경 후 로직 (FlashAttentionDecodeQKVProgram)
// QK^T 계산, Softmax, V 곱셈까지 한 커널에서 처리
// 중간 qk 텐서 제거
// 최종 결과는 out_split_vx로 출력
FlashAttentionDecodeQKVProgram은 이제 QK^T 계산, 어텐션 바이어스 및 인과 마스크 적용, 로컬 Softmax 계산, 정규화, 그리고 V와의 곱셈까지 한 커널 내에서 수행합니다. 또한, 온라인 Softmax를 위한 메타데이터(로컬 최대값 및 합계)를 생성하여 다음 단계인 VxReduce 커널에서 최종 정규화를 수행하도록 합니다.
WGSL 코드 예시 (퓨전된 커널의 일부):
// flash_attention_decode_qkv.wgsl.template (일부)
// ... QK^T 계산 및 Softmax 로직 ...
// 로컬 Softmax 메타데이터 생성
let local_max = ...;
let local_sum = ...;
// V와 곱셈 및 결과 분할 (out_split_vx로 출력)
var out_split_vx = ...;
// 메타데이터 출력 (VxReduce에서 사용)
metadata[global_idx] = vec2<f32>(local_max, local_sum);
리뷰어 hariharans29는 이 온라인 Softmax 재스케일링 수학이 정확함을 확인했습니다. 기존의 partial_i (로컬 정규화된 값)에 VxReduce 단계에서 rescale_i를 곱함으로써 최종적으로 표준 Softmax 가중치 V와 동일한 결과를 얻게 됩니다.
2. 시퀀스 길이 확장 및 m_tile 최적화
이전 디코딩 경로는 sequence_length == 1로 고정되어 있었습니다. 이번 PR에서는 이를 모든 시퀀스 길이에 대해 동작하도록 확장했습니다. 특히, m_tile 파라미터를 도입하여 각 워크그룹이 여러 개의 Q 행(1, 2, 또는 4개)을 처리하도록 하여 KV 로딩의 오버헤드를 분산시킵니다.
코드 변경 (C++):
@@ -16,6 +16,40 @@ namespace onnxruntime {
namespace contrib {
namespace webgpu {
+// WGSL helper function for normalizing on-device indirect dispatch dims.
+// Shared by CopyKVCacheProgram and SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram.
+// Mirrors ProgramManager::NormalizeDispatchGroupSize three tiers:
+// 1) direct (x, y, z) write when every dim is within the spec limit (65535);
+// 2) 2D sqrt collapse when the product fits a square layout;
+// 3) 3D cbrt collapse otherwise.
+// Consumers are unaffected by the chosen layout: ShaderHelper flattens
+// workgroup_id (x, y, z) into a single linear workgroup_idx.
+// Caller contract: must register a storage output named exactly
+// `indirect_buffer` of array<u32> with at least 3 elements.
+constexpr const char kNormalizeDispatchGroupSizeFn[] = R"(
+fn normalize_dispatch_group_size(x: u32, y: u32, z: u32) {
+ let limit = 65535u; // WebGPU spec maxComputeWorkgroupsPerDimension
+ if (x <= limit && y <= limit && z <= limit) {
+ indirect_buffer[0] = x;
+ indirect_buffer[1] = y;
+ indirect_buffer[2] = z;
+ return;
+ }
+ let size = f32(x) * f32(y) * f32(z);
+ let dispatch_avg_2d = u32(ceil(sqrt(size)));
+ if (dispatch_avg_2d <= limit) {
+ indirect_buffer[0] = dispatch_avg_2d;
+ indirect_buffer[1] = dispatch_avg_2d;
+ indirect_buffer[2] = 1u;
+ return;
+ }
+ let dispatch_avg_3d = u32(ceil(pow(size, 1.0 / 3.0)));
+ indirect_buffer[0] = dispatch_avg_3d;
+ indirect_buffer[1] = dispatch_avg_3d;
+ indirect_buffer[2] = dispatch_avg_3d;
+}
+)";
+
Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform);
const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform);
@@ -28,6 +62,7 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha
if (prepare_indirect_dispatch_) {
sh.AddOutput("indirect_buffer", ShaderUsage::None);
+ sh.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn;
}
return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template",
kNormalizeDispatchGroupSizeFn는 WebGPU의 최대 워크그룹 차원 제한(65535)을 초과하지 않도록 디스패치 차원을 조정하는 헬퍼 함수입니다. 이는 CopyKVCacheProgram과 SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram 모두에서 사용됩니다.
또한, FlashAttentionDecodeQKVProgram 생성 시 m_tile_ 파라미터가 전달되어 m_tile 값으로 사용됩니다. 이는 q_BNSH (Query, Key, Value, Bias, SequenceLength, Head) 포맷과 is_unidirectional (단방향 어텐션 여부) 플래그와 함께 셰이더 동작을 결정합니다.
// FlashAttentionDecodeQKVProgram 생성 시 m_tile 전달
FlashAttentionDecodeQKVProgram program{
"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile
};
3. 경로 라우팅 최적화
이전에는 (sequence_length < 4) || (sequence_length < 32 && total_sequence_length > 1000)과 같은 복잡한 휴리스틱으로 Prefill 경로와 Split-Reduce 경로를 선택했습니다. 하지만 그래프 캡처 시 total_sequence_length_가 0이 되는 경우가 많아 해당 조건이 제대로 동작하지 않는 문제가 있었습니다.
이번 PR에서는 이 휴리스틱을 sequence_length < 32로 단순화했습니다. 이는 Copilot의 리뷰 코멘트에서 지적된 그래프 캡처 시 total_sequence_length_ == 0 문제를 해결하고, PR 설명에 명시된 것처럼 실제 벤치마크 결과에 기반하여 Split-Reduce 경로가 짧은 시퀀스 길이에서 일관되게 더 빠르다는 점을 활용한 것입니다.
코드 변경 (C++):
@@ -224,52 +258,66 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
WGSL_TEMPLATE_PARAMETER(use_shm_path, use_shm_path_));
}
-Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const {
- shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
- shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const {
+ const auto& q = shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+ const auto& present_key = shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+ const auto& present_value = shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
if (use_indirect_dispatch_) {
shader.AddInput("seqlens_k", ShaderUsage::None);
}
if (has_attention_bias_) {
shader.AddInput("attention_bias", ShaderUsage::UseUniform);
}
- shader.AddOutput("output", ShaderUsage::UseUniform);
- shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+ const auto& out_split_vx = shader.AddOutput("out_split_vx", ShaderUsage::UseUniform);
+ const auto& metadata = shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const uint32_t tile_size_k_vec = 8;
const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec;
- return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkt.wgsl.template",
+ return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkv.wgsl.template",
WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_),
+ WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_),
+ WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_),
+ WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_),
WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count),
WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec),
- WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
+ WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_),
+ WGSL_TEMPLATE_PARAMETER(v_head_size_vec, head_size_vec_),
+ WGSL_TEMPLATE_VARIABLE(metadata, metadata),
+ WGSL_TEMPLATE_VARIABLE(out_split_vx, out_split_vx),
+ WGSL_TEMPLATE_VARIABLE(present_key, present_key),
+ WGSL_TEMPLATE_VARIABLE(present_value, present_value),
+ WGSL_TEMPLATE_VARIABLE(q, q));
}
-Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q,
- const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata, const Tensor* seqlen_k,
- const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length) {
+Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q,
+ const Tensor* attention_bias, Tensor* out_split_vx, Tensor* present_key, Tensor* present_value,
+ Tensor* metadata, const Tensor* seqlen_k,
+ const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile) {
const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
: parameters.scale_;
const bool has_attention_bias = attention_bias != nullptr;
const int components = 4;
+ const int head_size_vec = parameters.v_head_size_ / components;
- FlashAttentionDecodeQKTProgram program{"FlashAttentionDecodeQKT", has_attention_bias, tile_size, use_indirect_dispatch};
+ bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH;
+ bool is_unidirectional = parameters.is_unidirectional_;
+ FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
- {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}});
+ {present_key, ProgramTensorMetadataDependency::TypeAndRank, components},
+ {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (use_indirect_dispatch) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
}
if (has_attention_bias) {
program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank});
}
- program.AddOutputs({{output, ProgramTensorMetadataDependency::Rank},
+ program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components},
{metadata, ProgramTensorMetadataDependency::Rank, 2}});
const uint32_t vectorized_head_size = parameters.head_size_ / components;
- // Get attention bias dimensions for broadcasting
uint32_t attn_bias_dim0 = 1;
uint32_t attn_bias_dim1 = 1;
if (has_attention_bias) {
@@ -281,10 +329,10 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
if (use_indirect_dispatch) {
program.SetIndirectDispatchTensor(indirect_buffer);
} else {
- program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_total_seq_length_tile);
+ program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_total_seq_length_tile);
}
program.SetWorkgroupSize(64)
- .CacheHint(tile_size, has_attention_bias, use_indirect_dispatch)
+ .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile)
.AddUniformVariables({{static_cast<uint32_t>(vectorized
Copilot은 이 라우팅 휴리스틱이 그래프 캡처 시 total_sequence_length_ == 0 문제로 인해 잘못 동작할 수 있음을 지적했으나, qjia7의 커밋 13ea0e4c58에서 parameters.sequence_length_ < 32로 수정되어 이 문제가 해결되었습니다. 또한, PR 설명에 명시된 벤치마크 결과(sequence_length ∈ {16, 30, 31} × total_sequence_length ∈ {128, 500, 2000} 범위에서 Split-Reduce가 1.13×–2.07× 빠름)를 통해 이 단순화된 휴리스틱이 타당함을 뒷받침합니다.
4. 기타 개선 사항 및 리뷰 피드백
indirect_buffer관리:kNormalizeDispatchGroupSizeFn함수는 WebGPU의maxComputeWorkgroupsPerDimension제한을 준수하도록 디스패치 크기를 조정합니다. 이는CopyKVCacheProgram과SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram모두에서 사용됩니다.- 메모리 사용량: 커널 퓨전으로
qk텐서가 제거되었지만,out_split_vx버퍼의 크기가B × H × seq × present_seq에서B × H × seq × num_present_sequence_length_tile × v_head_size로 증가했습니다. 리뷰어hariharans29는 이로 인해 메모리 사용량이 약간 증가할 수 있지만(예: seq=31, present_seq=8K, fp16 기준 약 34MB 증가), 디스패치 횟수 감소 및 대역폭 병목 현상 완화를 고려할 때 합리적인 절충이라고 평가했습니다. use_seqlen_k처리: Copilot은use_seqlen_k사용 시seqlens_k와 디스패치 크기 간의 불일치로 인한 데이터 경쟁 가능성을 지적했습니다.qjia7의 커밋b690d0f37b에서use_seqlen_k플래그가use_indirect_dispatch로 변경되고, GPU 측에서seqlens_k[0]으로부터 직접 디스패치 크기를 계산하도록 수정하여 이 문제가 해결되었습니다.m_tile선택:m_tile값의 선택 로직은 PR diff에서 명확히 보이지 않지만,#param(specialization constant)으로 처리되어 런타임 시 일관성을 유지한다고 가정합니다. (리뷰어hariharans29의 관찰)cpplint경고:flash_attention.h파일에서#include <string>누락이 지적되었고, 이는 병합 전에 수정되었습니다.
왜 이게 좋은가
이 PR은 여러 측면에서 뛰어난 최적화 및 개선을 보여줍니다:
- 성능 향상: 가장 주목할 만한 결과는 Whisper 디코딩의 Prefill 성능이 4.68ms에서 1.09ms로 약 4.3배 향상된 것입니다. 이는 커널 퓨전,
m_tile을 통한 KV 로딩 효율화, 그리고 더 적합한 경로 선택 휴리스틱 덕분입니다. - 메모리 효율성: 중간
qk텐서 제거는 메모리 사용량을 줄여 더 큰 배치 크기나 더 긴 시퀀스를 처리할 수 있는 여지를 제공합니다. 비록out_split_vx버퍼의 크기가 증가했지만, 전체적인 메모리 대역폭 사용량 감소와 디스패치 횟수 감소 효과가 더 클 수 있습니다. - 일반화 및 유연성: 기존에 고정된 시퀀스 길이만 지원하던 디코딩 경로를 모든 시퀀스 길이에 대해 동작하도록 확장하여 모델 적용 범위를 넓혔습니다. 또한,
use_seqlen_k관련 문제를 해결하고 간결한 라우팅 휴리스틱을 도입하여 코드의 견고성과 유지보수성을 높였습니다. - 기술적 교훈:
- 커널 퓨전의 힘: 여러 단계의 연산을 하나로 합치는 것은 GPU에서 상당한 성능 향상을 가져올 수 있습니다. 중간 데이터의 메모리 저장 및 로딩 오버헤드를 제거하고, 레지스터 활용도를 높일 수 있습니다.
- 워크로드 특성에 맞는 경로 선택: 모든 연산이 모든 워크로드에 최적인 것은 아닙니다. 모델의 특성(시퀀스 길이, 배치 크기 등)에 따라 최적의 실행 경로를 동적으로 선택하는 것은 성능 최적화의 핵심입니다. 벤치마크 기반의 휴리스틱 개선은 이러한 최적화를 뒷받침합니다.
- 온라인 알고리즘의 활용: 온라인 Softmax와 같이 중간 결과를 활용하여 최종 결과를 계산하는 기법은 메모리 사용량을 줄이면서도 정확성을 유지하는 데 효과적입니다.
- 간결함과 견고성의 균형: 복잡한 조건문이나 예외 처리는 코드를 이해하기 어렵게 만들고 버그를 유발할 수 있습니다. PR은 벤치마크 결과를 바탕으로 불필요한 조건을 제거하고 코드를 단순화하여 견고성을 높였습니다.
결론
이번 ONNX Runtime의 WebGPU FlashAttention 최적화 PR은 커널 퓨전, 시퀀스 길이 확장, 그리고 경로 라우팅 개선을 통해 LLM 추론 성능을 크게 향상시키는 성공적인 사례입니다. 특히 Whisper 디코딩 성능의 획기적인 개선은 이러한 최적화의 실질적인 효과를 보여줍니다. 이 PR은 GPU 컴퓨팅 최적화의 중요성과 함께, 워크로드 특성을 고려한 알고리즘 설계 및 구현의 가치를 다시 한번 강조합니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://onnxruntime.ai/docs/reference/webgpu.html
- https://github.com/microsoft/onnxruntime/blob/main/docs/ONNX_Runtime_WebGPU.md
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [onnxruntime] ONNX Runtime CUDA MoE: 소규모 배치 디코딩을 위한 SoftmaxTopK 라우터 최적화
- [sglang] 실시간 RGB 전송 속도 향상을 위한 최적화 분석
- [axolotl] Axolotl MoE 모델 최적화: Tiled-MLP 도입 및 FSDP2 통합으로 성능 극대화
- [onnxruntime] ONNX Runtime CUTLASS FMHA: BiasLoader 정렬 문제 해결로 안정성 및 호환성 향상
- [onnxruntime] Apple M4 Max를 위한 FlashAttention 최적화: 20배 성능 향상 분석
PR Analysis 의 다른글
- 이전글 [sglang] SGLang Diffusion 모델의 FP8 GEMM 최적화: 41.5% 성능 향상 달성
- 현재글 : [onnxruntime] WebGPU FlashAttention 최적화: 커널 퓨전과 가변 시퀀스 길이 지원으로 성능 극대화
- 다음글 [onnxruntime] ONNX Runtime CUDA MoE: 소규모 배치 디코딩을 위한 SoftmaxTopK 라우터 최적화
댓글