[SGLang] KV Cache Offloading: Decode 중 메모리 오프로딩
들어가며
Disaggregated Decode 서버에서 긴 시퀀스를 처리할 때, GPU 메모리에 모든 KV 캐시를 유지하기 어려울 수 있다. SGLang의 DecodeKVCacheOffloadManager는 Prefill에서 전송받은 KV 캐시를 점진적으로 CPU(Host) 메모리로 오프로딩하고, 필요 시 외부 스토리지에 백업하는 3단계 계층 구조를 구현한다.
구조도
┌───────────────────────────────────────────────┐
│ GPU (Device) │
│ ┌─────────────┐ │
│ │ KV Cache │ ──offload──► ┌─────────────┐ │
│ │ (Active) │ │ Host Memory │ │
│ └─────────────┘ │ (CPU Pool) │ │
│ └──────┬──────┘ │
│ │ │
│ backup │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ Storage │ │
│ │ (External) │ │
│ └──────────────┘ │
└───────────────────────────────────────────────┘
타임라인:
Prefill KV 수신 → Decode 시작 → 점진적 offload → GPU free → backup → host free
핵심 코드 분석
초기화: Host 메모리 풀 생성
python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py에서 KV 캐시 타입에 따라 Host 메모리 풀을 생성한다.
class DecodeKVCacheOffloadManager:
def __init__(self, req_to_token_pool, token_to_kv_pool_allocator,
tp_group, tree_cache, server_args):
kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
if isinstance(kv_cache, MHATokenToKVPool):
self.decode_host_mem_pool = MHATokenToKVPoolHost(
kv_cache, server_args.hicache_ratio,
server_args.hicache_size, self.page_size, ...)
elif isinstance(kv_cache, MLATokenToKVPool):
self.decode_host_mem_pool = MLATokenToKVPoolHost(
kv_cache, server_args.hicache_ratio, ...)
MHA(Multi-Head Attention)와 MLA(Multi-Latent Attention) 각각에 맞는 Host 풀을 사용한다.
HiCacheController: 캐시 컨트롤러
GPU↔Host↔Storage 간 데이터 이동을 관리하는 컨트롤러를 초기화한다.
self.cache_controller = HiCacheController(
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
mem_pool_host=self.decode_host_mem_pool,
page_size=self.page_size,
tp_group=tp_group,
io_backend=server_args.hicache_io_backend,
storage_backend=server_args.hicache_storage_backend,
model_name=server_args.served_model_name,
)
점진적 오프로딩: offload_kv_cache
Decode 중 증분적으로 생성되는 KV 캐시를 offload_stride 단위로 GPU에서 Host로 이동시킨다.
def offload_kv_cache(self, req) -> bool:
all_tokens = req.origin_input_ids + req.output_ids[:-1]
prefill_offloaded_len = len(req.origin_input_ids) // self.page_size * self.page_size
state = self.offloaded_state.get(req.rid)
if state is None:
state = OffloadedState(
prefill_len=prefill_offloaded_len,
inc_len=0,
last_hash=last_prefill_hash,
)
self.offloaded_state[req.rid] = state
incremental_total = len(all_tokens) - state.prefill_len
incremental_new = incremental_total - state.inc_len
incremental_aligned_len = incremental_new // self.offload_stride * self.offload_stride
if incremental_aligned_len == 0:
return False
OffloadedState는 각 요청의 오프로딩 진행 상태를 추적한다. Prefill에서 받은 부분(prefill_len)과 Decode에서 생성된 증분(inc_len)을 분리 관리한다.
비동기 Host 전송
실제 GPU→Host 전송은 cache_controller.write를 통해 비동기로 수행된다.
host_indices = self.cache_controller.write(
device_indices=incremental_indices.long(),
node_id=ack_id,
)
if host_indices is None:
logger.error(f"Not enough host memory for request {req.rid}")
return False
self.ongoing_offload[ack_id] = (req, host_indices, incremental_tokens, time.time(), start, end)
state.inc_len += incremental_aligned_len
오프로딩 진행 확인
TP 환경에서 모든 랭크의 진행 상태를 동기화한 뒤 완료된 전송을 처리한다.
def check_offload_progress(self):
qsizes = torch.tensor([
len(cc.ack_write_queue),
cc.ack_backup_queue.qsize(),
], dtype=torch.int)
if self.tp_world_size > 1:
torch.distributed.all_reduce(qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group)
n_write, n_backup = map(int, qsizes.tolist())
self._check_offload_progress(n_write)
self._check_backup_progress(n_backup)
GPU 메모리 조기 해제
오프로딩이 완료되면 해당 구간의 GPU 메모리를 즉시 해제한다.
def _check_offload_progress(self, finish_count):
while finish_count > 0:
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
finish_event.synchronize()
for ack_id in ack_list:
req, host_indices, incremental_tokens, start_time, start, end = self.ongoing_offload.pop(ack_id)
if req.finished():
self._release_finished_req(req, start)
else:
kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, start:end]
self.token_to_kv_pool_allocator.free(kv_indices)
이를 통해 GPU에는 최근 토큰의 KV 캐시만 유지하고, 과거 토큰은 Host에 보관한다.
스토리지 백업
Host 메모리의 KV 캐시를 외부 스토리지로 비동기 백업하여 Host 메모리도 회수한다.
def _trigger_backup(self, req, host_indices, incremental_tokens, start_time, prior_hash):
page_hashes = self._compute_prefix_hash(incremental_tokens, prior_hash)
ack_id = self.cache_controller.write_storage(
host_indices, incremental_tokens, hash_value=page_hashes,
)
self.ongoing_backup[ack_id] = (req.rid, host_indices, start_time)
Prefix hash를 사용하여 동일 프리픽스 중복 저장을 방지한다.
Finalize: 미정렬 꼬리 처리
요청 완료 시 offload_stride에 정렬되지 않은 나머지 KV 캐시를 정리한다.
def finalize_release_on_finish(self, req):
state = self.offloaded_state.get(req.rid)
if prefill_len > 0 and inc_len == 0:
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
self.token_to_kv_pool_allocator.free(token_indices[:prefill_len])
start_offset = prefill_len + inc_len
self._release_finished_req(req, start_offset)
설계 근거
offload_stride와 페이지 정렬
오프로딩 단위(offload_stride)는 페이지 크기의 배수로 설정된다. 이는 GPU 메모리 해제를 페이지 단위로 수행하기 위함이다.
env_stride = envs.SGLANG_HICACHE_DECODE_OFFLOAD_STRIDE.get()
self.offload_stride = max(self.page_size, (env_stride // self.page_size) * self.page_size)
Prefix Hash 기반 중복 제거
동일 프리픽스를 공유하는 요청들의 KV 캐시를 스토리지에 한 번만 저장한다. 해시 체인 구조로 이전 페이지의 해시를 다음 페이지 해시 계산에 포함시킨다.
def _compute_prefix_hash(self, tokens, prior_hash=""):
page_hashes = []
last_hash = prior_hash
for offset in range(0, len(tokens), self.page_size):
page_tokens = tokens[offset : offset + self.page_size]
last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
page_hashes.append(last_hash)
return page_hashes
관련 포스트
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Disaggregated Decode 서버: 디코드 전용 서버 구현
- 현재글 : [SGLang] KV Cache Offloading: Decode 중 메모리 오프로딩
- 다음글 [SGLang] Disaggregation 커넥터: Mooncake, NIXL, MORI 전송 엔진
댓글