[sglang] SGLang P/D Disaggregation: Decode-Side Radix Cache 도입으로 LLM 추론 성능 극대화
PR 링크: sgl-project/sglang#19746 상태: Merged | 변경: +None / -None
들어가며
최근 LLM(Large Language Model)의 발전 속도는 눈부시지만, 동시에 모델의 크기가 커짐에 따라 추론 시 발생하는 메모리 및 연산 비용 또한 기하급수적으로 증가하고 있습니다. 특히 대화형 애플리케이션처럼 여러 턴에 걸쳐 동일한 프롬프트를 반복적으로 처리해야 하는 경우, 이전 턴의 KV Cache를 효율적으로 재사용하는 것이 성능 향상의 핵심입니다. SGLang은 이러한 문제를 해결하기 위해 P/D (Prefill & Decode) Disaggregation 아키텍처를 도입했으며, 이번 PR은 이 아키텍처에 Decode-Side Radix Cache를 추가하여 KV Cache 재사용률을 극대화하고 추론 성능을 획기적으로 개선하는 내용을 담고 있습니다.
기존 P/D Disaggregation 방식에서는 각 턴마다 Prefill 단계에서 생성된 KV Cache를 Decode 워커로 전송해야 했습니다. 이는 특히 긴 대화 히스토리를 가진 경우 상당한 네트워크 오버헤드를 유발했습니다. 본 PR에서 제안하는 Decode-Side Radix Cache는 Decode 워커 자체적으로 공유되는 프롬프트(prefix)의 KV Cache를 효율적으로 관리하고 재사용함으로써, Prefill 워커로부터 전송받아야 하는 KV Cache의 양을 최소화합니다. 결과적으로, 요청 처리량(throughput)과 응답 속도(latency) 모두에서 상당한 성능 향상을 가져왔습니다.
이 글에서는 해당 PR의 주요 변경 사항을 코드 diff와 함께 살펴보고, 왜 이러한 변경이 성능 향상으로 이어지는지, 그리고 이 최적화가 가지는 일반적인 교훈은 무엇인지 분석해 보겠습니다.
코드 분석
이번 PR의 핵심은 Decode 워커가 Prefill 워커로부터 받는 KV Cache의 양을 줄이기 위해 자체적으로 Radix Cache를 활용하는 로직을 추가한 것입니다. 변경 사항은 주로 decode.py, common/conn.py, base/conn.py 파일에 집중되어 있습니다.
1. Decode Scheduler (decode.py)
Decode 워커의 스케줄러는 새로운 요청이 들어왔을 때, 이를 기존에 Decode 워커가 가지고 있는 Radix Cache와 매칭하는 로직을 추가했습니다. 만약 매칭되는 Prefix가 있다면, 해당 Prefix에 해당하는 KV Cache 노드를 잠그고(lock) 요청 수명 주기 동안 유지합니다. 그리고 Prefill 워커에게는 이전에 매칭된 Prefix를 제외한 '델타(delta)' KV 페이지만 요청하도록 변경되었습니다.
주요 변경점:
-
_match_prefix_and_lock함수 추가: 새로운 요청(req)이 들어왔을 때,self.tree_cache(Radix Cache)를 사용하여 기존 KV Cache와 매칭합니다. 매칭된 Prefix 노드의 Lock Reference Count를 증가시켜 해당 노드가 Eviction되지 않도록 보호합니다.@@ -448,6 +456,25 @@ def add(self, req: Req, is_retracted: bool = False) -> None: + def _match_prefix_and_lock(self, req: Req) -> Tuple[torch.Tensor, int]: + """ + Match a request against the decode-side radix cache, lock the matched + node to prevent eviction, and return the matched prefix information. + """ + result = match_prefix_for_req( + self.tree_cache, + req, + req.origin_input_ids, + cow_mamba=self.tree_cache.supports_mamba(), + include_req=True, + ) + prefix_indices = result.device_indices + last_device_node = result.last_device_node + # Always lock to match aggregated scheduling behavior + self.tree_cache.inc_lock_ref(last_device_node) + + return prefix_indices, len(prefix_indices) + def _resolve_prefill_dp_rank(self, req: Req) -> Optional[int]: def pop_preallocated( -
pop_preallocated함수 수정: 요청을 할당할 때, Radix Cache 매칭 결과를 반영하여prefix_len을 계산하고, 이prefix_len만큼은 Prefill 워커로부터 전송받지 않도록 KV 인덱스를 조정합니다.@@ -736,14 +763,42 @@ def pop_preallocated( + if self.scheduler.server_args.disaggregation_decode_enable_radix_cache: + # Match prefix against decode's radix cache. + prefix_indices, prefix_len = self._match_prefix_and_lock(decode_req.req) + # Align prefix_len down to page boundary so both prefill and + # decode agree on the page-aligned split point for KV transfer. + page_size = self.token_to_kv_pool_allocator.page_size + if page_size > 1 and prefix_len % page_size != 0: + prefix_len = page_align_floor(prefix_len, page_size) + prefix_indices = prefix_indices[:prefix_len] + + fill_len = origin_input_len + max(len(decode_req.req.output_ids) - 1, 0) + required_alloc_tokens = self._required_alloc_tokens( + fill_len=fill_len, prefix_len=prefix_len + ) + # Matching may lock previously-evictable radix pages, so refresh + # the admission budget against the post-lock pool state before we + # decide whether this request still fits. + allocatable_tokens = self._allocatable_tokens( + retractable_tokens=retractable_tokens, + count_retracted=True, + extra_reserved_reqs=len(preallocated_reqs), + ) + else: + prefix_indices = None + prefix_len = 0 + required_alloc_tokens = origin_input_len + required_tokens_for_request = ( - origin_input_len + self.num_reserved_decode_tokens + required_alloc_tokens + self.num_reserved_decode_tokens ) if ( + max( + required_tokens_for_request, + origin_input_len + - prefix_len + + min( + decode_req.req.sampling_params.max_new_tokens, + CLIP_MAX_NEW_TOKEN, + ) + ) + > allocatable_tokens + ): + if prefix_len > 0: + self.tree_cache.dec_lock_ref(decode_req.req.last_node) + break + if required_tokens_for_request > allocatable_tokens: + if prefix_len > 0: + self.tree_cache.dec_lock_ref(decode_req.req.last_node) + break + - allocatable_tokens -= required_tokens_for_request + dst_kv_indices = self._pre_alloc(decode_req.req, prefix_indices, prefix_len) + # Recompute from actual pool state for the next queue entry. + # This accounts for page rounding and newly locked evictable cache. + allocatable_tokens = self._allocatable_tokens( + retractable_tokens=retractable_tokens, + count_retracted=True, + extra_reserved_reqs=len(preallocated_reqs) + 1, + ) + decode_req.req.cache_protected_len = prefix_len + - dst_kv_indices = self._pre_alloc(decode_req.req) + # Only send delta indices (beyond prefix) to prefill. + kv_indices = ( + self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][ - dst_kv_indices[:origin_input_len].cpu().numpy().astype(np.int32) + prefix_len:origin_input_len + ] + .cpu() + .numpy() + ) + page_size = self.token_to_kv_pool_allocator.page_size + + # Prepare extra pool indices for hybrid models @@ -821,7 +874,10 @@ assert decode_req.metadata_buffer_index is not None page_indices = kv_to_page_indices(kv_indices, page_size) decode_req.kv_receiver.send_metadata( - page_indices, decode_req.metadata_buffer_index, state_indices + page_indices, + decode_req.metadata_buffer_index, + state_indices, + decode_prefix_len=prefix_len, ) if ( @@ -848,7 +904,10 @@ ) def _allocatable_tokens( - self, retractable_tokens: Optional[int] = None, count_retracted: bool = True + self, + retractable_tokens: Optional[int] = None, + count_retracted: bool = True, + extra_reserved_reqs: int = 0, ) -> int: need_space_for_single_req = ( max( -
send_metadata함수 수정:decode_prefix_len인자를 추가하여 Prefill 워커에게 현재 요청의 Prefix 길이를 전달합니다. 이를 통해 Prefill 워커는 이 길이만큼은 KV Cache를 전송하지 않아도 됨을 알 수 있습니다.@@ -136,6 +142,7 @@ def send_metadata( kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None,
-
): """decode_prefix_len: Optional[int] = None, state_indices: Optional[List[int]] = None,
2. Common Connection (common/conn.py)
Prefill 워커의 커넥션 로직에서 decode_prefix_len을 받아 처리하는 부분이 추가되었습니다. 또한, Prefill 워커가 Decode 워커로부터 KV Cache를 받을 때, decode_prefix_len 만큼은 이미 Decode 워커가 가지고 있다고 간주하여 전송 대상에서 제외합니다.
-
CommonKVSender.send_metadata함수 수정:decode_prefix_len인자를 받아transfer_infos에 저장합니다.@@ -489,6 +496,12 @@ def init(self, num_kv_indices: int, aux_index: Optional[int] = None): f"CommonKVSender init with num_kv_indices: {num_kv_indices} and aux_index: {aux_index}"
- def pop_decode_prefix_len(self) -> int:
-
return 0 - def should_send_kv_chunk(self, num_pages: int, last_chunk: bool) -> bool:
-
return num_pages > 0 - def send(
self,
kv_indices: npt.NDArray[np.int32],
-
CommonKVReceiver.update_status함수 수정: 요청 실패 시req_to_decode_prefix_len에서 해당 요청 정보를 제거합니다.@@ -179,6 +185,12 @@ def check_status(self, bootstrap_room: int) -> KVPoll: return self.request_status[bootstrap_room]
- def update_status(self, bootstrap_room: int, status: KVPoll):
-
if ( -
status == KVPoll.Failed -
and self.disaggregation_mode == DisaggregationMode.PREFILL -
and hasattr(self, "req_to_decode_prefix_len") -
): -
self.req_to_decode_prefix_len.pop(bootstrap_room, None) -
if bootstrap_room not in self.request_status: -
self.request_status[bootstrap_room] = status -
else:
3. Prefill Transfer Path (common/conn.py)
Prefill 워커의 CommonKVSender는 이제 decode_prefix_len을 고려하여 실제로 전송해야 할 KV Chunk의 양을 결정합니다. 만약 decode_prefix_len이 num_pages보다 크거나 같으면, 해당 청크는 전송하지 않아도 됩니다.
-
CommonKVSender.send함수 수정:decode_prefix_len을 사용하여 실제로 전송할 KV 인덱스를 결정합니다.@@ -489,6 +496,12 @@ def init(self, num_kv_indices: int, aux_index: Optional[int] = None): f"CommonKVSender init with num_kv_indices: {num_kv_indices} and aux_index: {aux_index}"
- def pop_decode_prefix_len(self) -> int:
-
return 0 - def should_send_kv_chunk(self, num_pages: int, last_chunk: bool) -> bool:
-
return num_pages > 0 - def send(
self,
kv_indices: npt.NDArray[np.int32],
리뷰 댓글에서 ShangmingCai가 should_send_kv_chunk 로직이 NIXL 백엔드 외 다른 백엔드에서 문제가 될 수 있는지 질문했지만, ishandhanani는 현재 NIXL만 지원하며 다른 백엔드는 추후 PR에서 지원할 예정이라고 답변했습니다. 또한, decode_prefix_len 관련 로직을 스케줄러 스레드가 아닌 KV Manager의 비동기 스레드에서 처리하도록 개선하는 제안도 있었습니다.
왜 이게 좋은가?
이 PR의 핵심적인 개선점은 KV Cache의 재사용률을 극대화하여 통신 오버헤드를 줄이고, 결과적으로 전체 추론 성능을 향상시키는 것입니다.
1. 성능 향상
PR 설명에 포함된 벤치마크 결과는 이 최적화의 효과를 명확하게 보여줍니다.
- 요청 처리량 (Request throughput): 1.21 req/s에서 1.59 req/s로 1.32배 증가했습니다.
- 출력 토큰 처리량 (Output token throughput): 430 tok/s에서 566 tok/s로 1.32배 증가했습니다.
- TTFT (Time To First Token) p50: 73.2초에서 9.0초로 8.1배 감소했습니다. 이는 첫 토큰 생성까지의 시간이 획기적으로 단축되었음을 의미하며, 대화형 서비스에서 사용자 경험을 크게 향상시킬 수 있습니다.
- 요청 지연 시간 (Request latency) p50: 99.1초에서 73.4초로 1.35배 감소했습니다.
이러한 성능 향상은 Decode 워커가 자체적으로 Radix Cache를 활용하여 공유되는 Prefix KV Cache를 재사용함으로써, Prefill 워커로부터 전송받아야 하는 데이터 양이 크게 줄어들었기 때문입니다. 벤치마크 결과에서 Decode 워커의 KV Cache 사용률이 0.99에서 0.75로 감소하고, 동시에 처리 가능한 요청 수가 37개에서 104-126개로 증가한 점이 이를 뒷받침합니다.
2. 일반적인 교훈
이 PR은 LLM 추론 최적화에 있어 다음과 같은 중요한 교훈을 제공합니다:
- 분산 시스템에서의 캐싱 전략: P/D Disaggregation와 같이 분산된 환경에서는 각 컴포넌트가 어떻게 캐시를 공유하고 재사용할 것인지가 성능의 핵심입니다. Decode 워커가 자체적으로 Prefix Cache를 관리하는 것은 통신 비용을 절감하는 효과적인 방법입니다.
- 데이터 전송량 최적화: 단순히 계산량을 줄이는 것을 넘어, 노드 간 통신량을 줄이는 것이 전체 시스템 성능에 큰 영향을 미칩니다. KV Cache의 '델타(delta)'만 전송하는 방식은 이러한 최적화의 좋은 예입니다.
- 점진적 개선: 새로운 기능(Decode-Side Radix Cache)을 도입하면서도 기존 기능(NIXL 백엔드)과의 호환성을 유지하고, 점진적으로 다른 백엔드(Mooncake)로 확장하려는 접근 방식은 안정적인 시스템 개발에 중요합니다.
- 정확한 벤치마킹: 다양한 지표(처리량, TTFT, 지연 시간 등)를 종합적으로 측정하고, 특히 ITL(Inter-Token Latency)과 같은 잠재적인 성능 저하 요인도 분석하여 개선점을 찾는 것이 중요합니다. 리뷰어
cctry가 지적한 ITL 회귀 현상에 대한 논의는 이러한 과정을 잘 보여줍니다.
결론
SGLang의 P/D Disaggregation 아키텍처에 Decode-Side Radix Cache를 도입한 이 PR은 LLM 추론 성능을 획기적으로 개선하는 중요한 발걸음입니다. 공유되는 Prefix KV Cache를 Decode 워커가 자체적으로 관리하고 재사용함으로써, 통신 오버헤드를 줄이고 처리량을 높이며 응답 속도를 단축시키는 효과를 가져왔습니다. 이는 분산 LLM 추론 시스템 설계에 있어 캐싱 전략과 통신 최적화의 중요성을 다시 한번 강조합니다. 앞으로 Mooncake 백엔드 지원 등 추가적인 개선이 기대됩니다.
References
- SGLang P/D Disaggregation Overview (가상 링크, 실제 문서 확인 필요)
- Radix Cache in LLM Inference (유사 개념 설명, 실제 SGLang 구현과 다를 수 있음)
- NIXL Transfer Backend (NIXL 백엔드 구현)
참고 자료
- https://github.com/sgl-project/sglang/blob/main/docs/disaggregation.md
- https://huggingface.co/blog/radix-cache
- https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/disaggregation/nixl/conn.py
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang UnifiedRadixTree에 HiCache 프레임워크 도입: 하이브리드 모델 성능 최적화
- [sglang] SGLang, FP4 KV 캐시 도입으로 LLM 추론 성능 극대화: NVFP4 최적화 분석
- [sglang] SGLang NIXL 이기종 TP 환경에서 디스어그리게이션 KV 캐시 전송 버그 수정 및 성능 개선
- [sglang] HiSparse 도입: Sparse Attention 모델을 위한 효율적인 KV 캐시 관리
- [LlamaFactory] LlamaFactory: Qwen-VL 비디오 토큰 전처리 최적화로 450배 성능 향상 달성
PR Analysis 의 다른글
- 이전글 [vllm] vLLM, DCP A2A 어텐션 백엔드 최적화: 단일 All-to-All 콜렉티브로 성능 향상
- 현재글 : [sglang] SGLang P/D Disaggregation: Decode-Side Radix Cache 도입으로 LLM 추론 성능 극대화
- 다음글 [vllm] vLLM의 첫 추론 지연 문제 해결: forward_native 샘플러 커널 웜업 최적화
댓글