본문으로 건너뛰기

[SGLang] GPU Memory Pool: 블록 기반 KV 캐시 메모리 할당

들어가며

LLM inference에서 KV 캐시는 GPU 메모리의 대부분을 차지한다. 매 요청마다 torch.malloc으로 메모리를 할당하고 해제하면, GPU 메모리 단편화와 커널 오버헤드가 심각한 병목이 된다. SGLang은 이 문제를 2단계 메모리 풀 아키텍처로 해결한다.

이 글에서는 python/sglang/srt/mem_cache/memory_pool.py를 중심으로 Memory Pool의 설계를 분석한다.

2단계 메모리 풀 아키텍처

SGLang의 메모리 관리는 세 개의 계층으로 구성된다.

┌─────────────────────────────────────────────────────┐
  Layer 1: ReqToTokenPool                            
  요청  토큰 위치 매핑 (req_pool_idx  token indices)│
  [req_to_token: Tensor(size, max_context_len)]      
├─────────────────────────────────────────────────────┤
  Layer 2: TokenToKVPoolAllocator                    
  토큰 위치  KV 캐시 인덱스 할당                     
  [free_pages: Tensor]  alloc/free                  
├─────────────────────────────────────────────────────┤
  Layer 3: KVCache (MHATokenToKVPool 등)             
  실제 GPU 메모리에 K, V 텐서 저장                    
  [k_buffer, v_buffer: Tensor(size, head, dim)]      
└─────────────────────────────────────────────────────┘

Layer 1이 요청을 토큰 위치에 매핑하고, Layer 2가 토큰 위치를 실제 KV 캐시 슬롯에 매핑하고, Layer 3이 물리적 GPU 메모리를 보유한다.

ReqToTokenPool: 요청-토큰 매핑

각 요청이 사용하는 KV 캐시 토큰 위치를 관리한다.

class ReqToTokenPool:
    def __init__(self, size, max_context_len, device, enable_memory_saver):
        self.req_to_token = torch.zeros(
            (size, max_context_len), dtype=torch.int32, device=device
        )
        self.free_slots = list(range(size))

    def alloc(self, reqs: list[Req]) -> Optional[List[int]]:
        # 이미 slot을 가진 요청(chunked prefill) 재사용
        reusing = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None]
        need_size = len(reqs) - len(reusing)
        if need_size > len(self.free_slots):
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        offset = 0
        for r in reqs:
            if r.req_pool_idx is None:
                r.req_pool_idx = select_index[offset]
                offset += 1
        return [r.req_pool_idx for r in reqs]

핵심은 chunked prefill 지원이다. 긴 프롬프트를 여러 chunk로 나눠 처리할 때, 같은 요청은 이미 할당된 slot을 재사용한다.

KVCache: 물리적 GPU 메모리

KVCache는 추상 클래스로, 다양한 attention 아키텍처를 지원한다.

class KVCache(abc.ABC):
    def __init__(self, size, page_size, dtype, layer_num, device, ...):
        self.size = size
        self.page_size = page_size
        self.dtype = dtype
        self.store_dtype = (torch.uint8
            if dtype in (torch.float8_e5m2, torch.float8_e4m3fn)
            else dtype)

FP8 양자화 KV 캐시를 지원하기 위해 store_dtype을 별도로 관리한다. FP8 타입에 대해 Tensor.index_put이 지원되지 않으므로 torch.uint8로 저장한다.

MHA(Multi-Head Attention) KV 캐시의 실제 구현을 보자.

class MHATokenToKVPool(KVCache):
    def __init__(self, size, page_size, dtype, head_num, head_dim,
                 layer_num, device, ...):
        super().__init__(size, page_size, dtype, layer_num, device, ...)
        self.head_num = head_num
        self.head_dim = head_dim
        self._create_buffers()
        self.row_dim = self.head_num * self.head_dim
        self.same_kv_dim = self.head_dim == self.v_head_dim

set_kv_buffer는 KV 캐시에 데이터를 쓰는 핵심 연산인데, 성능을 위해 JIT 커널을 우선 사용한다.

def _set_kv_buffer_impl(k, v, k_cache, v_cache, indices, row_dim,
                         store_dtype, device_module, alt_stream, same_kv_dim):
    row_bytes = row_dim * store_dtype.itemsize
    if (_is_cuda or _is_hip) and same_kv_dim and can_use_store_cache(row_bytes):
        return store_cache(
            k.view(-1, row_dim), v.view(-1, row_dim),
            k_cache.view(-1, row_dim), v_cache.view(-1, row_dim),
            indices, row_bytes=row_bytes,
        )
    # fallback: CUDA Graph 모드에서는 alt_stream 활용
    if get_is_capture_mode() and alt_stream is not None:
        current_stream = device_module.current_stream()
        alt_stream.wait_stream(current_stream)
        k_cache[indices] = k
        with device_module.stream(alt_stream):
            v_cache[indices] = v
        current_stream.wait_stream(alt_stream)
    else:
        k_cache[indices] = k
        v_cache[indices] = v

K와 V를 별도 스트림에서 병렬로 쓰는 최적화가 CUDA Graph 캡처 모드에서 적용된다. 이는 메모리 대역폭을 2배 활용하는 효과가 있다.

HybridReqToTokenPool: Mamba 지원

Transformer-Mamba 하이브리드 모델을 위한 확장이다.

class HybridReqToTokenPool(ReqToTokenPool):
    def alloc(self, reqs: List["Req"]) -> Optional[List[int]]:
        select_index = super().alloc(reqs)
        if select_index is None:
            return None

        mamba_indices = []
        for req in reqs:
            if req.mamba_pool_idx is not None:
                mid = req.mamba_pool_idx  # radix cache에서 재사용
            else:
                mid = self.mamba_pool.alloc(1)
                req.mamba_pool_idx = mid
            mamba_indices.append(mid)

        mamba_index_tensor = torch.stack(mamba_indices).to(dtype=torch.int32)
        self.req_index_to_mamba_index_mapping[select_index] = mamba_index_tensor
        return select_index

KV 캐시(Transformer 레이어용)와 Mamba 상태(conv_state, temporal_state)를 동시에 관리한다. mamba_pool은 별도의 MambaPool 인스턴스로 Mamba 레이어의 상태를 할당/해제한다.

MambaPool: 선형 상태 관리

Mamba 모델의 conv_state와 temporal_state를 관리하는 전용 풀이다.

class MambaPool:
    @dataclass(frozen=True, kw_only=True)
    class State:
        conv: List[torch.Tensor]
        temporal: torch.Tensor

    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
        if need_size > len(self.free_slots):
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        # 할당 시점에 즉시 초기화 (GPU zero 텐서 expand)
        for i in range(len(self.mamba_cache.conv)):
            t = self.mamba_cache.conv[i]
            z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
                t.shape[0], need_size, *t.shape[2:])
            t[:, select_index] = z
        return select_index

할당 시점에 상태를 0으로 초기화하는 것이 핵심이다. 스칼라 GPU zero 텐서를 expand하여 CPU-GPU 동기화 없이 효율적으로 초기화한다.

설계 근거

사전 할당 방식을 선택한 이유: GPU 메모리의 동적 할당은 비용이 크다. 서버 시작 시 전체 KV 캐시 메모리를 한 번에 할당하고, 이후에는 인덱스 기반으로 슬롯을 관리한다. free_pages 텐서에서 앞부분을 잘라내는 것이 alloc의 전부이므로 O(1)에 가깝다.

2단계 분리의 이유: 요청-토큰 매핑(ReqToTokenPool)과 토큰-KV 매핑(Allocator)을 분리하면, Radix Tree가 KV 인덱스를 공유할 때 ReqToTokenPool은 수정 없이 동일 인덱스를 참조할 수 있다. 즉, 프리픽스 캐싱의 이점을 자연스럽게 흡수한다.

alt_stream 최적화: CUDA Graph 캡처 시 K, V를 별도 스트림에서 병렬로 쓴다. 메모리 바운드 연산이므로 스트림 병렬화가 실질적 성능 향상을 제공한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글