본문으로 건너뛰기

[SGLang] Hybrid Cache Controller: GPU/CPU 하이브리드 캐시 관리

들어가며

GPU 메모리만으로 모든 KV Cache를 유지하기는 어렵다. 긴 컨텍스트 모델이나 동시 요청이 많은 환경에서는 GPU 메모리가 빠르게 소진된다. 그렇다고 evict하면 같은 prefix가 재요청될 때 처음부터 다시 계산해야 한다.

SGLang의 HybridCacheController는 GPU와 CPU 메모리를 계층적으로 관리하여 이 문제를 해결한다. GPU에서 evict된 KV Cache를 CPU로 백업하고, 필요하면 다시 GPU로 로드한다. python/sglang/srt/mem_cache/hybrid_cache/hybrid_cache_controller.py를 중심으로 분석한다.

구조도

┌─────────────────────────────────────────────────────┐
│                    GPU (Device)                      │
│  ┌───────────────────────────────────────────────┐  │
│  │   mem_pool_device (KV Cache + Mamba State)     │  │
│  │   token_to_kv_pool_allocator                   │  │
│  └───────────────────┬───────────────────────────┘  │
│                      │ write_stream / load_stream    │
└──────────────────────┼──────────────────────────────┘
                       ▼
┌─────────────────────────────────────────────────────┐
│                    CPU (Host)                         │
│  ┌───────────────────────────────────────────────┐  │
│  │   mem_pool_host (HostKVCache)                  │  │
│  │   ├── kv_pool (KV 백업)                        │  │
│  │   └── extra pools (Mamba 등)                   │  │
│  └───────────────────┬───────────────────────────┘  │
│                      │ storage_backend (optional)    │
└──────────────────────┼──────────────────────────────┘
                       ▼
┌─────────────────────────────────────────────────────┐
│              External Storage (Optional)             │
│  NIXL / Mooncake / 3FS / File / EIC                 │
└─────────────────────────────────────────────────────┘

CacheOperation: 전송 단위

GPU-CPU 간 데이터 이동의 기본 단위는 CacheOperation이다. KV 인덱스와 함께 Mamba 같은 추가 풀의 전송 정보도 포함한다.

class CacheOperation(BaseCacheOperation):
    def __init__(
        self,
        host_indices: torch.Tensor,
        device_indices: torch.Tensor,
        node_id: int,
        priority: Optional[int] = None,
        pool_transfers: Optional[list[PoolTransfer]] = None,
    ):
        super().__init__(host_indices, device_indices, node_id, priority)
        self.pool_transfers = pool_transfers

여러 연산을 하나로 합치는 merge_ops가 핵심이다. 작은 전송을 모아서 큰 전송으로 만들어 GPU-CPU 간 대역폭을 효율적으로 사용한다.

@staticmethod
def merge_ops(ops: List["CacheOperation"]) -> "CacheOperation":
    if len(ops) == 1:
        return ops[0]
    host_indices = torch.cat([op.host_indices for op in ops])
    device_indices = torch.cat([op.device_indices for op in ops])
    priority = min(op.priority for op in ops)
    merged = CacheOperation(
        host_indices, device_indices, -1, priority,
        pool_transfers=CacheOperation.merge_pool_transfers(ops),
    )
    return merged

Write: GPU에서 CPU로 백업

write()는 GPU의 KV Cache를 CPU로 복사한다. host 메모리를 할당하고 큐에 넣은 뒤, start_writing()에서 비동기 복사를 실행한다.

def write(self, device_indices, priority=None, node_id=-1,
          extra_pools=None) -> Optional[torch.Tensor]:
    host_indices = self.mem_pool_host.alloc(len(device_indices))
    if host_indices is None:
        return None
    pool_transfers = self._resolve_pool_transfers_allocation(
        extra_pools, alloc_host=True
    )
    self.write_queue.append(
        CacheOperation(host_indices, device_indices, node_id,
                       priority, pool_transfers=pool_transfers)
    )
    self.start_writing()
    return host_indices

start_writing()에서는 별도의 CUDA stream에서 비동기 복사를 수행한다. 메인 연산 stream을 차단하지 않는다.

def start_writing(self) -> None:
    if not self.write_queue:
        return
    op = CacheOperation.merge_ops(self.write_queue)
    host_indices, device_indices = self.move_indices(op)
    self.write_queue.clear()
    start_event = device_module.Event()
    finish_event = device_module.Event()
    start_event.record()
    with device_module.stream(self.write_stream):
        start_event.wait(self.write_stream)
        self.mem_pool_host.backup_from_device_all_layer(
            self.mem_pool_device, host_indices, device_indices,
            self.io_backend, pool_transfers=op.pool_transfers,
        )
        finish_event.record()

Load: CPU에서 GPU로 복원

load()는 CPU에 백업된 KV Cache를 GPU로 복원한다. device 메모리를 할당하고 큐에 넣는다. start_loading()에서 레이어 단위로 복사한다.

def start_loading(self) -> int:
    if not self.load_queue:
        return -1
    producer_id = self.layer_done_counter.update_producer()
    op = CacheOperation.merge_ops(self.load_queue)
    self.load_queue.clear()
    with device_module.stream(self.load_stream):
        for i in range(self.layer_num):
            self.mem_pool_host.load_to_device_per_layer(
                self.mem_pool_device, host_indices,
                device_indices, i, self.io_backend,
                pool_transfers=op.pool_transfers,
            )
            producer_event.complete(i)

레이어 단위 로딩은 pipelining을 가능하게 한다. 0번 레이어가 GPU로 복사되면, GPU는 0번 레이어의 attention 연산을 시작할 수 있다. 1번 레이어가 복사되는 동안 0번 레이어의 연산이 병렬로 진행된다.

Prefetch: 예측 기반 선제 로딩

prefetch()는 아직 필요하지 않지만 곧 필요할 것으로 예상되는 KV Cache를 미리 로드한다.

def prefetch(self, request_id, host_indices, new_input_tokens,
             last_hash=None, prefix_keys=None,
             extra_pools=None) -> PrefetchOperation:
    operation = PrefetchOperation(
        request_id, host_indices, new_input_tokens,
        last_hash, prefix_keys=prefix_keys,
        pool_transfers=extra_pools,
    )
    self.prefetch_queue.put(operation)
    return operation

PrefetchOperation은 thread-safe한 종료 메커니즘을 가진다. 예측이 틀렸을 때(요청이 취소됐을 때) 진행 중인 prefetch를 안전하게 중단할 수 있다.

class PrefetchOperation(StorageOperation):
    def increment(self, num_tokens: int):
        with self._lock:
            if self._terminated_flag:
                return False
            self.completed_tokens += num_tokens
            return True

    def mark_terminate(self):
        with self._lock:
            self._terminated_flag = True

PoolTransfer: 다중 풀 관리

하이브리드 모델(Transformer + Mamba)은 KV Cache 외에 Mamba 상태 풀도 함께 전송해야 한다. PoolTransfer가 이를 추상화한다.

@dataclass
class PoolTransfer:
    name: PoolName          # "kv" 또는 "mamba"
    host_indices: Optional[torch.Tensor] = None
    device_indices: Optional[torch.Tensor] = None
    keys: Optional[List[str]] = None
    hit_policy: PoolHitPolicy = PoolHitPolicy.ALL_PAGES

_resolve_pool_transfers_allocation()은 PoolTransfer에서 누락된 인덱스를 자동으로 할당한다. host에서 device로 로드할 때 device_indices가 None이면 자동 할당하고, 실패 시 이전에 할당한 모든 인덱스를 롤백한다.

데이터 흐름 요약

요청 완료 시:
  GPU KV → write() → CPU host
         → write_storage() → External Storage (optional)

새 요청, prefix cache hit (GPU):
  GPU KV 직접 재활용 (이동 없음)

새 요청, prefix cache hit (CPU):
  CPU host → load() → GPU KV

새 요청, prefix cache hit (Storage):
  External Storage → prefetch() → CPU host → GPU KV

설계 근거

HybridCacheController가 BaseHiCacheController를 상속하면서 확장하는 핵심은 pool_transfers다. 기본 컨트롤러는 KV Cache만 처리하지만, 하이브리드 컨트롤러는 Mamba 상태 등 추가 풀을 동일한 파이프라인으로 전송한다. 이는 하이브리드 모델(Jamba 등)에서 KV와 Mamba 상태의 일관성을 보장한다.

관련 포스트

  • Mamba Radix Cache: SSM 모델을 위한 상태 캐싱
  • 외부 스토리지 백엔드: LMCache, 3FS, Mooncake, NIXL
  • 캐시 Eviction 정책: LRU, LFU, FIFO 비교 분석

참고

댓글

관련 포스트

SGLang 의 다른글