본문으로 건너뛰기

[SGLang] KV Cache Offloading: Decode 중 메모리 오프로딩

들어가며

Disaggregated Decode 서버에서 긴 시퀀스를 처리할 때, GPU 메모리에 모든 KV 캐시를 유지하기 어려울 수 있다. SGLang의 DecodeKVCacheOffloadManager는 Prefill에서 전송받은 KV 캐시를 점진적으로 CPU(Host) 메모리로 오프로딩하고, 필요 시 외부 스토리지에 백업하는 3단계 계층 구조를 구현한다.

구조도

┌───────────────────────────────────────────────┐
│  GPU (Device)                                 │
│  ┌─────────────┐                              │
│  │ KV Cache    │ ──offload──► ┌─────────────┐ │
│  │ (Active)    │              │ Host Memory │ │
│  └─────────────┘              │ (CPU Pool)  │ │
│                               └──────┬──────┘ │
│                                      │        │
│                               backup │        │
│                                      ▼        │
│                            ┌──────────────┐   │
│                            │  Storage     │   │
│                            │  (External)  │   │
│                            └──────────────┘   │
└───────────────────────────────────────────────┘

타임라인:
  Prefill KV 수신 → Decode 시작 → 점진적 offload → GPU free → backup → host free

핵심 코드 분석

초기화: Host 메모리 풀 생성

python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py에서 KV 캐시 타입에 따라 Host 메모리 풀을 생성한다.

class DecodeKVCacheOffloadManager:
    def __init__(self, req_to_token_pool, token_to_kv_pool_allocator,
                 tp_group, tree_cache, server_args):
        kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
        if isinstance(kv_cache, MHATokenToKVPool):
            self.decode_host_mem_pool = MHATokenToKVPoolHost(
                kv_cache, server_args.hicache_ratio,
                server_args.hicache_size, self.page_size, ...)
        elif isinstance(kv_cache, MLATokenToKVPool):
            self.decode_host_mem_pool = MLATokenToKVPoolHost(
                kv_cache, server_args.hicache_ratio, ...)

MHA(Multi-Head Attention)와 MLA(Multi-Latent Attention) 각각에 맞는 Host 풀을 사용한다.

HiCacheController: 캐시 컨트롤러

GPU↔Host↔Storage 간 데이터 이동을 관리하는 컨트롤러를 초기화한다.

self.cache_controller = HiCacheController(
    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
    mem_pool_host=self.decode_host_mem_pool,
    page_size=self.page_size,
    tp_group=tp_group,
    io_backend=server_args.hicache_io_backend,
    storage_backend=server_args.hicache_storage_backend,
    model_name=server_args.served_model_name,
)

점진적 오프로딩: offload_kv_cache

Decode 중 증분적으로 생성되는 KV 캐시를 offload_stride 단위로 GPU에서 Host로 이동시킨다.

def offload_kv_cache(self, req) -> bool:
    all_tokens = req.origin_input_ids + req.output_ids[:-1]
    prefill_offloaded_len = len(req.origin_input_ids) // self.page_size * self.page_size

    state = self.offloaded_state.get(req.rid)
    if state is None:
        state = OffloadedState(
            prefill_len=prefill_offloaded_len,
            inc_len=0,
            last_hash=last_prefill_hash,
        )
        self.offloaded_state[req.rid] = state

    incremental_total = len(all_tokens) - state.prefill_len
    incremental_new = incremental_total - state.inc_len
    incremental_aligned_len = incremental_new // self.offload_stride * self.offload_stride

    if incremental_aligned_len == 0:
        return False

OffloadedState는 각 요청의 오프로딩 진행 상태를 추적한다. Prefill에서 받은 부분(prefill_len)과 Decode에서 생성된 증분(inc_len)을 분리 관리한다.

비동기 Host 전송

실제 GPU→Host 전송은 cache_controller.write를 통해 비동기로 수행된다.

host_indices = self.cache_controller.write(
    device_indices=incremental_indices.long(),
    node_id=ack_id,
)
if host_indices is None:
    logger.error(f"Not enough host memory for request {req.rid}")
    return False
self.ongoing_offload[ack_id] = (req, host_indices, incremental_tokens, time.time(), start, end)
state.inc_len += incremental_aligned_len

오프로딩 진행 확인

TP 환경에서 모든 랭크의 진행 상태를 동기화한 뒤 완료된 전송을 처리한다.

def check_offload_progress(self):
    qsizes = torch.tensor([
        len(cc.ack_write_queue),
        cc.ack_backup_queue.qsize(),
    ], dtype=torch.int)
    if self.tp_world_size > 1:
        torch.distributed.all_reduce(qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group)
    n_write, n_backup = map(int, qsizes.tolist())
    self._check_offload_progress(n_write)
    self._check_backup_progress(n_backup)

GPU 메모리 조기 해제

오프로딩이 완료되면 해당 구간의 GPU 메모리를 즉시 해제한다.

def _check_offload_progress(self, finish_count):
    while finish_count > 0:
        _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
        finish_event.synchronize()
        for ack_id in ack_list:
            req, host_indices, incremental_tokens, start_time, start, end = self.ongoing_offload.pop(ack_id)
            if req.finished():
                self._release_finished_req(req, start)
            else:
                kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, start:end]
                self.token_to_kv_pool_allocator.free(kv_indices)

이를 통해 GPU에는 최근 토큰의 KV 캐시만 유지하고, 과거 토큰은 Host에 보관한다.

스토리지 백업

Host 메모리의 KV 캐시를 외부 스토리지로 비동기 백업하여 Host 메모리도 회수한다.

def _trigger_backup(self, req, host_indices, incremental_tokens, start_time, prior_hash):
    page_hashes = self._compute_prefix_hash(incremental_tokens, prior_hash)
    ack_id = self.cache_controller.write_storage(
        host_indices, incremental_tokens, hash_value=page_hashes,
    )
    self.ongoing_backup[ack_id] = (req.rid, host_indices, start_time)

Prefix hash를 사용하여 동일 프리픽스 중복 저장을 방지한다.

Finalize: 미정렬 꼬리 처리

요청 완료 시 offload_stride에 정렬되지 않은 나머지 KV 캐시를 정리한다.

def finalize_release_on_finish(self, req):
    state = self.offloaded_state.get(req.rid)
    if prefill_len > 0 and inc_len == 0:
        token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
        self.token_to_kv_pool_allocator.free(token_indices[:prefill_len])
    start_offset = prefill_len + inc_len
    self._release_finished_req(req, start_offset)

설계 근거

offload_stride와 페이지 정렬

오프로딩 단위(offload_stride)는 페이지 크기의 배수로 설정된다. 이는 GPU 메모리 해제를 페이지 단위로 수행하기 위함이다.

env_stride = envs.SGLANG_HICACHE_DECODE_OFFLOAD_STRIDE.get()
self.offload_stride = max(self.page_size, (env_stride // self.page_size) * self.page_size)

Prefix Hash 기반 중복 제거

동일 프리픽스를 공유하는 요청들의 KV 캐시를 스토리지에 한 번만 저장한다. 해시 체인 구조로 이전 페이지의 해시를 다음 페이지 해시 계산에 포함시킨다.

def _compute_prefix_hash(self, tokens, prior_hash=""):
    page_hashes = []
    last_hash = prior_hash
    for offset in range(0, len(tokens), self.page_size):
        page_tokens = tokens[offset : offset + self.page_size]
        last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
        page_hashes.append(last_hash)
    return page_hashes

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글