[sglang] SGLang의 MLA KV 캐시 쓰기 최적화: TMA Bulk-Store 도입
PR 링크: sgl-project/sglang#25311 상태: Merged | 변경: +0 / -0
들어가며
LLM 추론 엔진에서 MLA(Multi-Head Latent Attention) 구조의 KV 캐시를 관리하는 것은 성능의 병목이 되기 쉽습니다. 특히 set_mla_kv_buffer와 같은 paged-KV scatter-write 작업은 배치 사이즈가 커질수록 기존 Triton 커널의 오버헤드가 선형적으로 증가하는 문제가 있었습니다. 본 PR은 GB300(Blackwell) 아키텍처의 TMA(Tensor Memory Accelerator) 기능을 활용하여, 대규모 배치 환경에서 성능을 획기적으로 개선한 사례를 다룹니다.
코드 분석
1. JIT TMA Bulk-Store 커널 도입 (set_mla_kv_buffer.cuh)
기존의 1D Triton 커널은 각 요소별로 벡터 스토어를 수행했으나, 새로운 JIT 커널은 cp.async.bulk.global.shared::cta를 사용하여 행 단위로 데이터를 전송합니다.
// Lane 0 issues one bulk store from the smem slot to the scattered gmem row.
if (threadIdx.x % kWarpThreads == 0) {
cuda::ptx::cp_async_bulk(
cuda::ptx::space_global,
cuda::ptx::space_shared,
gmem_dst,
&smem[warp_in_cta][0],
static_cast<uint32_t>(kRowBytes));
}
여기서 중요한 점은 fence.proxy.async.shared::cta를 사용하여 공유 메모리(smem)와 TMA 엔진 간의 데이터 일관성을 보장하는 것입니다. TMA는 비동기 프록시를 통해 smem을 읽기 때문에, 이 펜스가 없으면 대규모 배치에서 stale한 데이터를 읽는 문제가 발생할 수 있습니다.
2. Dispatcher 최적화 (utils.py)
배치 사이즈에 따라 최적의 커널을 선택하는 디스패처를 도입했습니다. 768 이상의 배치 사이즈에서는 TMA를 사용하고, 그 미만에서는 BLOCK 크기를 최적화한 Triton 커널을 사용하여 오버헤드를 최소화했습니다.
if n_loc >= 768 and is_arch_support_pdl() and can_use_set_mla_kv_buffer(...):
# TMA bulk-store: packs 4-8 items per CTA
jit_set_mla_kv_buffer(...)
else:
# Triton single-CTA-per-loc with BLOCK = next_pow2(total_dim)
BLOCK = triton.next_power_of_2(nope_dim + rope_dim)
set_mla_kv_buffer_kernel[(n_loc, 1)](..., BLOCK=BLOCK, ...)
왜 이게 좋은가
이 최적화의 핵심은 CTA(Cooperative Thread Array) 팬아웃 감소와 메모리 대역폭 효율화입니다.
- 성능 수치: 배치 사이즈 4096에서 기존 21.6 µs 대비 1.76 µs로 약 12배의 성능 향상을 기록했습니다.
- 일반적 교훈:
- 하드웨어 가속기 활용: 최신 아키텍처(Blackwell 등)에서 제공하는 TMA와 같은 특수 명령어를 활용하면 메모리 복사 오버헤드를 극적으로 줄일 수 있습니다.
- 디스패처의 중요성: 모든 상황에 하나의 커널을 적용하기보다, 데이터 크기(배치 사이즈)에 따라 최적의 커널을 선택하는 전략이 실무 성능 최적화에 필수적입니다.
- 비동기 메모리 관리:
cp.async계열 명령어를 사용할 때는fence와wait_group을 정확히 사용하여 데이터 일관성과 동기화 지점을 명확히 해야 합니다.
리뷰 피드백 반영
리뷰어들은 GB300 아키텍처가 32바이트 벡터 메모리 액세스를 지원한다는 점을 지적하며, Lane 0에서만 수행되는 주소 계산을 최적화하여 중복 작업을 줄일 것을 제안했습니다. 이러한 세밀한 최적화가 커널의 최종 성능을 결정짓는 핵심 요소가 되었습니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang Triton 커널 최적화: libdevice.tanh 도입과 2D Strided Tensor 지원
- [sglang] SGLang의 디코드 성능 향상을 위한 Temperature 및 Softmax 커널 융합
- [sglang] SGLang의 MHC 파이프라인 최적화: 커널 퓨전과 DeepGemm 도입
- [sglang] SGLang 성능 최적화: torch.cuda.empty_cache() 호출 제어를 통한 가중치 업데이트 병목 해결
- [sglang] SGLang의 FA3 디코드 최적화: get_scheduler_metadata 도입
PR Analysis 의 다른글
- 이전글 [sglang] sglang diffusion 모델 성능 향상: Cache-DiT와 torch.compile의 최적화된 적용 순서
- 현재글 : [sglang] SGLang의 MLA KV 캐시 쓰기 최적화: TMA Bulk-Store 도입
- 다음글 [vllm] vLLM의 혁신: Breakable CUDA Graph로 LLM 추론 성능 최적화
댓글