[sglang] AMD GPU에서 FP8 KV 캐시 쓰기 최적화: Triton 커널 융합으로 성능 향상
PR 링크: sgl-project/sglang#23620 상태: Merged | 변경: +None / -None
들어가며
최근 대규모 언어 모델(LLM)의 발전과 함께 GPU 하드웨어의 발전도 가속화되고 있습니다. 특히 AMD GPU는 특정 워크로드에서 뛰어난 성능을 보여주며 주목받고 있습니다. 하지만 이러한 하드웨어의 잠재력을 최대한 활용하기 위해서는 소프트웨어 최적화가 필수적입니다.
SGLang 프로젝트의 이번 Pull Request(PR)는 AMD GPU 환경에서 FP8 KV 캐시를 사용할 때 발생하는 성능 병목 현상을 해결하는 데 초점을 맞추고 있습니다. 기존에는 KV 캐시 쓰기 작업이 두 개의 별도 커널 호출을 필요로 하여 불필요한 오버헤드가 발생했습니다. 이 PR은 이러한 과정을 하나의 통합된 Triton 커널로 융합하여 성능을 향상시키는 것을 목표로 합니다.
본 글에서는 이 PR이 해결하고자 하는 문제점을 명확히 하고, 코드 변경 사항을 상세히 분석하며, 이러한 최적화가 왜 효과적인지에 대해 기술적인 관점에서 설명하겠습니다.
코드 변경 분석
이번 PR의 핵심 변경 사항은 python/sglang/srt/layers/attention/aiter_backend.py 파일 내 forward_decode 함수의 로직 수정에 있습니다. 특히 AMD GPU에서 FP8 KV 캐시(--kv-cache-dtype fp8_e4m3)와 Unified Attention이 활성화된 경우, 기존의 두 단계로 나뉘었던 KV 캐시 쓰기 과정을 단일 Triton 커널로 통합하는 로직이 추가되었습니다.
기존 로직 (Before)
기존에는 FP8 KV 캐시를 사용할 때, 디코딩 과정에서 KV 캐시를 쓰는 작업이 두 개의 독립적인 커널 호출로 이루어졌습니다. 첫 번째는 bf16 데이터를 FP8 형식으로 변환하는 float8_copy_kernel이고, 두 번째는 변환된 데이터를 페이징된 메모리에 저장하는 store_kvcache였습니다. 이 두 커널은 별도의 커널 론칭 오버헤드를 발생시켰습니다.
# 기존 로직의 개념적 표현 (실제 diff와는 약간 다를 수 있음)
# ... (이전 코드)
else:
# bf16 -> fp8 캐스트
casted_k, casted_v = float8_copy_kernel(k, v, scale_k, scale_v)
# 페이징된 KV 캐시 저장
token_to_kv_pool.set_kv_buffer(layer, out_cache_loc, casted_k, casted_v)
# ... (이후 코드)
변경된 로직 (After)
이번 PR에서는 AiterAttnBackend.forward_decode 함수 내에 새로운 조건을 추가하여, AMD GPU와 FP8 KV 캐시 사용 시 launch_reshape_and_cache_flash라는 기존 Triton 커널을 활용하도록 변경했습니다. 이 커널은 이미 SWA(Scaled-dot-product Attention with FlashAttention) 모델에서 사용되던 것으로, bf16에서 FP8으로의 데이터 타입 캐스팅과 페이징된 메모리에 데이터를 저장하는 작업을 하나의 커널 실행으로 융합(fuse)합니다.
diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index 27dd145860fa..e52ae698428b 100755
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -2478,6 +2478,23 @@ def forward_decode(
k_scale=k_descale,
v_scale=v_descale,
)
+ elif self.use_triton_unified_attention and self.kv_cache_dtype == fp8_dtype:
+ # [PATCH] FP8 non-SWA: use launch_reshape_and_cache_flash to
+ # fuse bf16→fp8 cast + paged write in one Triton kernel,
+ # eliminating separate float8_copy + store_kvcache overhead.
+ token_to_kv_pool = forward_batch.token_to_kv_pool
+ k_cache, v_cache = token_to_kv_pool.get_kv_buffer(layer.layer_id)
+ launch_reshape_and_cache_flash(
+ k.view(-1, layer.tp_k_head_num, layer.qk_head_dim),
+ v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+ k_cache.view(
+ -1, self.page_size, layer.tp_k_head_num, layer.qk_head_dim
+ ),
+ v_cache.view(
+ -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+ ),
+ forward_batch.out_cache_loc,
+ )
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
변경된 코드는 다음과 같은 로직을 수행합니다:
self.use_triton_unified_attention과self.kv_cache_dtype == fp8_dtype조건이 참일 경우, 새로운 로직이 실행됩니다.forward_batch.token_to_kv_pool.get_kv_buffer를 통해 KV 캐시 버퍼를 가져옵니다.launch_reshape_and_cache_flash함수를 호출합니다. 이 함수는 입력k,v텐서를 받아, 이를 KV 캐시 버퍼에 저장하기 전에 필요한 데이터 타입 변환(bf16 to fp8)과 메모리 레이아웃 조정(reshape)을 수행하고, 최종적으로 페이징된 KV 캐시에 저장하는 모든 과정을 단일 Triton 커널 내에서 처리합니다.
이 변경을 통해 float8_copy_kernel과 store_kvcache로 나뉘었던 두 번의 커널 호출이 launch_reshape_and_cache_flash라는 단 한 번의 커널 호출로 대체됩니다. 이는 GPU 커널 론칭 오버헤드를 줄이고, 데이터 이동을 최소화하여 전반적인 성능을 향상시킵니다.
왜 이게 좋은가?
이번 PR의 최적화는 다음과 같은 이유로 매우 긍정적입니다.
-
커널 론칭 오버헤드 감소: GPU에서 커널을 실행하는 것은 상당한 오버헤드를 수반합니다. 특히 작은 연산들을 여러 개의 커널로 나누어 실행하면, 이 오버헤드가 누적되어 성능 저하의 주요 원인이 됩니다. 두 개의 커널 호출을 하나로 융합함으로써, 이러한 불필요한 오버헤드를 효과적으로 제거할 수 있습니다. 이는 특히 디코딩과 같이 반복적으로 발생하는 작은 연산들에서 큰 성능 향상을 가져올 수 있습니다.
-
데이터 이동 최소화: 기존 방식에서는 bf16 데이터를 FP8으로 캐스팅한 후, 그 결과를 다시 스토리지 커널로 전달해야 했습니다. 이 과정에서 중간 데이터가 메모리를 거치거나 레지스터 간 이동이 발생할 수 있습니다. 커널 융합을 통해 캐스팅된 FP8 데이터가 별도의 메모리 접근 없이 바로 KV 캐시에 저장될 수 있으므로, 데이터 이동 관련 오버헤드가 줄어듭니다. 이는 메모리 대역폭 사용량을 최적화하는 데 기여합니다.
-
Triton 커널의 활용: Triton은 Python과 유사한 문법으로 고성능 GPU 커널을 쉽게 작성할 수 있게 해주는 라이브러리입니다. 이미 검증된
launch_reshape_and_cache_flashTriton 커널을 재활용함으로써, 새로운 커널을 개발하는 데 드는 시간과 노력을 절약하고, 동시에 Triton의 최적화된 성능을 활용할 수 있습니다.
성능 향상 수치
PR 설명에 따르면, 이 변경은 AMD MI355X GPU에서 MiniMax-M2.5 FP8 모델(TP=4, ISL=8192, OSL=1024)을 사용한 벤치마크에서 다음과 같은 출력 처리량(output throughput) 향상을 보였습니다:
conc=64: +2.5%conc=32: +2.3%conc=4: 최대 +5.9%conc=128: +0.4% (회귀 없음)
이 수치들은 특히 배치 크기나 시퀀스 길이가 짧을 때(conc 값이 작을 때) 더 큰 성능 향상이 나타남을 시사합니다. 이는 작은 연산들의 융합이 이러한 시나리오에서 더 큰 효과를 발휘하기 때문으로 해석될 수 있습니다.
일반적인 교훈
이 PR은 다음과 같은 일반적인 최적화 교훈을 제공합니다:
- 커널 융합(Kernel Fusion): 여러 개의 작은 GPU 커널을 가능한 한 단일 커널로 융합하는 것은 성능 향상의 강력한 기법입니다. 특히 데이터 전처리, 후처리, 또는 중간 계산 단계에서 여러 단계의 연산이 연속적으로 발생하는 경우, 이를 융합하면 커널 론칭 오버헤드와 데이터 이동을 크게 줄일 수 있습니다.
- 하드웨어 특화 최적화: AMD GPU와 같이 특정 하드웨어 아키텍처의 특성을 고려한 최적화는 상당한 성능 향상을 가져올 수 있습니다. FP8과 같은 데이터 타입을 활용하고, 해당 하드웨어에 최적화된 커널(예: Triton)을 사용하는 것이 중요합니다.
- 기존 도구 활용: 새로운 기능을 구현할 때, 이미 존재하는 고성능 라이브러리나 커널(예: Triton 커널)을 재활용하는 것은 개발 효율성과 성능을 동시에 높이는 좋은 전략입니다.
리뷰 피드백 분석
제공된 PR 정보에는 구체적인 리뷰 댓글이 포함되어 있지 않습니다. 만약 리뷰어들이 이 변경에 대해 기술적인 피드백이나 우려를 제기했다면, 해당 내용을 분석하여 최적화의 타당성이나 잠재적 위험에 대한 추가적인 통찰을 제공할 수 있었을 것입니다. 예를 들어, 새로운 커널이 특정 하드웨어에서 예상치 못한 부작용을 일으키는지, 또는 FP8 정밀도 문제로 인한 정확도 저하 가능성 등에 대한 논의가 있었다면 이를 반영할 수 있습니다.
하지만 현재 정보만으로는 리뷰 피드백을 분석하기 어렵습니다. 다만, PR 설명에 'Accuracy Tests' 섹션에서 GSM8K 정확도가 93.3%로 베이스라인과 동일하다는 점을 명시한 것으로 보아, 정확도 저하에 대한 우려를 사전에 인지하고 검증했음을 알 수 있습니다. 이는 매우 중요한 부분이며, 성능 최적화 과정에서 정확도를 유지하는 것이 얼마나 중요한지를 보여줍니다.
결론
SGLang의 이번 PR은 AMD GPU 환경에서 FP8 KV 캐시를 사용하는 LLM 추론 성능을 크게 향상시키는 중요한 개선을 이루었습니다. 기존의 두 단계 커널 호출을 단일 Triton 커널로 융합함으로써 커널 론칭 오버헤드와 데이터 이동을 줄였고, 이를 통해 최대 5.9%의 출력 처리량 향상을 달성했습니다. 이는 LLM 추론의 효율성을 높이는 데 크게 기여할 것이며, 향후 더 많은 하드웨어 및 모델에 대한 최적화의 좋은 사례가 될 것입니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang NIXL 이기종 TP 환경에서 디스어그리게이션 KV 캐시 전송 버그 수정 및 성능 개선
- [sglang] SGLang 성능 최적화: torch.cuda.empty_cache() 호출 제어를 통한 가중치 업데이트 병목 해결
- [sglang] SGLang MoE 라우팅 최적화: AMD GPU에서 aiter.biased_grouped_topk 활용
- [sglang] SGLang Triton 커널 최적화: libdevice.tanh 도입과 2D Strided Tensor 지원
- [sglang] SGLang, Diffusion 모델의 RL 기반 후처리 최적화를 위한 새로운 Rollout API 및 정밀도 개선
PR Analysis 의 다른글
- 이전글 [flashinfer] FlashInfer 오토튜너 최적화: 하이브리드 토큰 버킷 도입
- 현재글 : [sglang] AMD GPU에서 FP8 KV 캐시 쓰기 최적화: Triton 커널 융합으로 성능 향상
- 다음글 [sglang] SGLang MoE 라우팅 최적화: AMD GPU에서 aiter.biased_grouped_topk 활용
댓글