[SGLang] GPU Memory Pool: 블록 기반 KV 캐시 메모리 할당
들어가며
LLM inference에서 KV 캐시는 GPU 메모리의 대부분을 차지한다. 매 요청마다 torch.malloc으로 메모리를 할당하고 해제하면, GPU 메모리 단편화와 커널 오버헤드가 심각한 병목이 된다. SGLang은 이 문제를 2단계 메모리 풀 아키텍처로 해결한다.
이 글에서는 python/sglang/srt/mem_cache/memory_pool.py를 중심으로 Memory Pool의 설계를 분석한다.
2단계 메모리 풀 아키텍처
SGLang의 메모리 관리는 세 개의 계층으로 구성된다.
┌─────────────────────────────────────────────────────┐
│ Layer 1: ReqToTokenPool │
│ 요청 → 토큰 위치 매핑 (req_pool_idx → token indices)│
│ [req_to_token: Tensor(size, max_context_len)] │
├─────────────────────────────────────────────────────┤
│ Layer 2: TokenToKVPoolAllocator │
│ 토큰 위치 → KV 캐시 인덱스 할당 │
│ [free_pages: Tensor] → alloc/free │
├─────────────────────────────────────────────────────┤
│ Layer 3: KVCache (MHATokenToKVPool 등) │
│ 실제 GPU 메모리에 K, V 텐서 저장 │
│ [k_buffer, v_buffer: Tensor(size, head, dim)] │
└─────────────────────────────────────────────────────┘
Layer 1이 요청을 토큰 위치에 매핑하고, Layer 2가 토큰 위치를 실제 KV 캐시 슬롯에 매핑하고, Layer 3이 물리적 GPU 메모리를 보유한다.
ReqToTokenPool: 요청-토큰 매핑
각 요청이 사용하는 KV 캐시 토큰 위치를 관리한다.
class ReqToTokenPool:
def __init__(self, size, max_context_len, device, enable_memory_saver):
self.req_to_token = torch.zeros(
(size, max_context_len), dtype=torch.int32, device=device
)
self.free_slots = list(range(size))
def alloc(self, reqs: list[Req]) -> Optional[List[int]]:
# 이미 slot을 가진 요청(chunked prefill) 재사용
reusing = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None]
need_size = len(reqs) - len(reusing)
if need_size > len(self.free_slots):
return None
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
offset = 0
for r in reqs:
if r.req_pool_idx is None:
r.req_pool_idx = select_index[offset]
offset += 1
return [r.req_pool_idx for r in reqs]
핵심은 chunked prefill 지원이다. 긴 프롬프트를 여러 chunk로 나눠 처리할 때, 같은 요청은 이미 할당된 slot을 재사용한다.
KVCache: 물리적 GPU 메모리
KVCache는 추상 클래스로, 다양한 attention 아키텍처를 지원한다.
class KVCache(abc.ABC):
def __init__(self, size, page_size, dtype, layer_num, device, ...):
self.size = size
self.page_size = page_size
self.dtype = dtype
self.store_dtype = (torch.uint8
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn)
else dtype)
FP8 양자화 KV 캐시를 지원하기 위해 store_dtype을 별도로 관리한다. FP8 타입에 대해 Tensor.index_put이 지원되지 않으므로 torch.uint8로 저장한다.
MHA(Multi-Head Attention) KV 캐시의 실제 구현을 보자.
class MHATokenToKVPool(KVCache):
def __init__(self, size, page_size, dtype, head_num, head_dim,
layer_num, device, ...):
super().__init__(size, page_size, dtype, layer_num, device, ...)
self.head_num = head_num
self.head_dim = head_dim
self._create_buffers()
self.row_dim = self.head_num * self.head_dim
self.same_kv_dim = self.head_dim == self.v_head_dim
set_kv_buffer는 KV 캐시에 데이터를 쓰는 핵심 연산인데, 성능을 위해 JIT 커널을 우선 사용한다.
def _set_kv_buffer_impl(k, v, k_cache, v_cache, indices, row_dim,
store_dtype, device_module, alt_stream, same_kv_dim):
row_bytes = row_dim * store_dtype.itemsize
if (_is_cuda or _is_hip) and same_kv_dim and can_use_store_cache(row_bytes):
return store_cache(
k.view(-1, row_dim), v.view(-1, row_dim),
k_cache.view(-1, row_dim), v_cache.view(-1, row_dim),
indices, row_bytes=row_bytes,
)
# fallback: CUDA Graph 모드에서는 alt_stream 활용
if get_is_capture_mode() and alt_stream is not None:
current_stream = device_module.current_stream()
alt_stream.wait_stream(current_stream)
k_cache[indices] = k
with device_module.stream(alt_stream):
v_cache[indices] = v
current_stream.wait_stream(alt_stream)
else:
k_cache[indices] = k
v_cache[indices] = v
K와 V를 별도 스트림에서 병렬로 쓰는 최적화가 CUDA Graph 캡처 모드에서 적용된다. 이는 메모리 대역폭을 2배 활용하는 효과가 있다.
HybridReqToTokenPool: Mamba 지원
Transformer-Mamba 하이브리드 모델을 위한 확장이다.
class HybridReqToTokenPool(ReqToTokenPool):
def alloc(self, reqs: List["Req"]) -> Optional[List[int]]:
select_index = super().alloc(reqs)
if select_index is None:
return None
mamba_indices = []
for req in reqs:
if req.mamba_pool_idx is not None:
mid = req.mamba_pool_idx # radix cache에서 재사용
else:
mid = self.mamba_pool.alloc(1)
req.mamba_pool_idx = mid
mamba_indices.append(mid)
mamba_index_tensor = torch.stack(mamba_indices).to(dtype=torch.int32)
self.req_index_to_mamba_index_mapping[select_index] = mamba_index_tensor
return select_index
KV 캐시(Transformer 레이어용)와 Mamba 상태(conv_state, temporal_state)를 동시에 관리한다. mamba_pool은 별도의 MambaPool 인스턴스로 Mamba 레이어의 상태를 할당/해제한다.
MambaPool: 선형 상태 관리
Mamba 모델의 conv_state와 temporal_state를 관리하는 전용 풀이다.
class MambaPool:
@dataclass(frozen=True, kw_only=True)
class State:
conv: List[torch.Tensor]
temporal: torch.Tensor
def alloc(self, need_size: int) -> Optional[torch.Tensor]:
if need_size > len(self.free_slots):
return None
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
# 할당 시점에 즉시 초기화 (GPU zero 텐서 expand)
for i in range(len(self.mamba_cache.conv)):
t = self.mamba_cache.conv[i]
z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
t.shape[0], need_size, *t.shape[2:])
t[:, select_index] = z
return select_index
할당 시점에 상태를 0으로 초기화하는 것이 핵심이다. 스칼라 GPU zero 텐서를 expand하여 CPU-GPU 동기화 없이 효율적으로 초기화한다.
설계 근거
사전 할당 방식을 선택한 이유: GPU 메모리의 동적 할당은 비용이 크다. 서버 시작 시 전체 KV 캐시 메모리를 한 번에 할당하고, 이후에는 인덱스 기반으로 슬롯을 관리한다. free_pages 텐서에서 앞부분을 잘라내는 것이 alloc의 전부이므로 O(1)에 가깝다.
2단계 분리의 이유: 요청-토큰 매핑(ReqToTokenPool)과 토큰-KV 매핑(Allocator)을 분리하면, Radix Tree가 KV 인덱스를 공유할 때 ReqToTokenPool은 수정 없이 동일 인덱스를 참조할 수 있다. 즉, 프리픽스 캐싱의 이점을 자연스럽게 흡수한다.
alt_stream 최적화: CUDA Graph 캡처 시 K, V를 별도 스트림에서 병렬로 쓴다. 메모리 바운드 연산이므로 스트림 병렬화가 실질적 성능 향상을 제공한다.
관련 포스트
- RadixAttention: Radix Tree 기반 프리픽스 캐싱의 핵심
- Allocator: 토큰-KV 풀 할당 전략의 설계
- HiRadixCache: 계층적 GPU/CPU/Disk KV 캐시
참고
관련 포스트
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
- [sglang] sglang diffusion 모델 성능 향상: Cache-DiT와 torch.compile의 최적화된 적용 순서
- [sglang] NixlKVManager 성능 향상: 비동기 및 멀티스레드 KV 전송 도입
SGLang 의 다른글
- 이전글 [SGLang] C++ Radix Tree: 고성능 캐시를 위한 네이티브 구현
- 현재글 : [SGLang] GPU Memory Pool: 블록 기반 KV 캐시 메모리 할당
- 다음글 [SGLang] Allocator: 토큰-KV 풀 할당 전략의 설계
댓글