본문으로 건너뛰기

[vLLM] Sleep Mode: GPU 메모리 동적 관리

들어가며

GPU 메모리는 비싸다. 서빙 중에 요청이 없는 유휴 시간에도 모델 가중치가 GPU를 점유하고 있으면 낭비다. vLLM의 Sleep Mode는 vllm/device_allocator/cumem.py에서 CuMem allocator를 사용하여, 유휴 시 GPU 메모리를 CPU로 오프로딩하고 필요할 때 복원하는 기능을 제공한다.

공식 문서

vLLM 공식 문서: Sleep Mode

핵심 구조/코드 분석

CuMemAllocator: 싱글톤 설계

class CuMemAllocator:
    """메모리 풀의 텐서를 관리하는 싱글톤 클래스.
    sleep 시 태그별로 오프로딩 또는 폐기.
    wake_up 시 오프로딩된 데이터를 복원."""

    instance: "CuMemAllocator | None" = None

    @staticmethod
    def get_instance() -> "CuMemAllocator":
        assert cumem_available
        if CuMemAllocator.instance is None:
            CuMemAllocator.instance = CuMemAllocator()
        return CuMemAllocator.instance

반드시 싱글톤이어야 하는 이유가 있다. C 확장이 전역 변수에 콜백 함수를 저장하므로, 여러 인스턴스를 만들면 콜백이 덮어써져 free 시 오동작한다.

태그 기반 메모리 관리

@contextmanager
def use_memory_pool(self, tag: str | None = None):
    if tag is None:
        tag = CuMemAllocator.default_tag
    old_tag = self.current_tag
    self.current_tag = tag
    with use_memory_pool_with_allocator(
        self.python_malloc_callback, self.python_free_callback
    ) as data:
        self.allocator_and_pools[tag] = data
        yield

use_memory_pool(tag) 컨텍스트 안에서 할당된 모든 텐서에 태그가 부여된다. sleep 시 태그별로 동작을 달리한다 - 모델 가중치는 CPU로 백업하고, KV 캐시는 폐기한다.

Sleep: GPU -> CPU 오프로딩

def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None:
    if offload_tags is None:
        offload_tags = (CuMemAllocator.default_tag,)

    for ptr, data in self.pointer_to_data.items():
        handle = data.handle
        total_bytes += handle[1]
        if data.tag in offload_tags:
            # CPU pinned 메모리로 백업
            cpu_backup_tensor = torch.empty(size_in_bytes, dtype=torch.uint8,
                device="cpu", pin_memory=is_pin_memory_available())
            libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes)
            data.cpu_backup_tensor = cpu_backup_tensor
        # GPU 메모리 해제 (unmap + release)
        unmap_and_release(handle)

    gc.collect()
    torch.cuda.empty_cache()

offload_tags에 포함된 메모리는 CPU로 백업 후 GPU에서 해제하고, 나머지는 백업 없이 바로 해제(폐기)한다. gc.collect()empty_cache()로 완전히 정리한다.

Wake Up: CPU -> GPU 복원

def wake_up(self, tags: list[str] | None = None) -> None:
    for ptr, data in self.pointer_to_data.items():
        if tags is None or data.tag in tags:
            handle = data.handle
            create_and_map(handle)  # GPU 메모리 재할당
            if data.cpu_backup_tensor is not None:
                cpu_ptr = data.cpu_backup_tensor.data_ptr()
                libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)  # CPU -> GPU 복원
                data.cpu_backup_tensor = None

create_and_map으로 동일한 주소에 GPU 메모리를 다시 매핑하고, CPU 백업에서 데이터를 복원한다. 기존 텐서들의 포인터가 변하지 않으므로 모델 코드 수정이 필요 없다.

PyTorch Pluggable Allocator

def get_pluggable_allocator(python_malloc_fn, python_free_func):
    init_module(python_malloc_fn, python_free_func)
    new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
        lib_name, "my_malloc", "my_free"
    )
    return new_alloc

PyTorch의 pluggable allocator API를 사용하여 커스텀 메모리 할당기를 등록한다. C 확장(cumem_allocator)이 실제 CUDA Driver API(cuMemCreate, cuMemMap, cuMemUnmap, cuMemRelease)를 호출한다.

할당 추적

def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
    py_d_mem = allocation_handle[2]
    self.pointer_to_data[py_d_mem] = AllocationData(allocation_handle, self.current_tag)

def _python_free_callback(self, ptr: int) -> HandleType:
    data = self.pointer_to_data.pop(ptr)
    return data.handle

모든 할당/해제를 pointer_to_data 딕셔너리로 추적한다. 포인터 -> (핸들, 태그, CPU 백업) 매핑으로, sleep/wake_up 시 어떤 메모리를 어떻게 처리할지 결정한다.

왜 이 설계인가

  1. CuMem Virtual Memory API 사용: 일반적인 cudaMalloc/cudaFree는 같은 주소에 재할당할 수 없다. CuMem의 가상 메모리 관리(cuMemMap/cuMemUnmap)를 사용하면 물리 메모리만 해제하고 가상 주소를 유지할 수 있어, wake_up 시 동일 주소에 재매핑이 가능하다.

  2. C 확장의 필요성: 코드 주석에 cuda-python 패키지와 ctypes 래퍼를 시도했지만 CUDA 컨텍스트 불일치로 실패했다고 적혀 있다. C 확장이 유일한 성공적 접근이었다.

  3. 태그 기반 선택적 오프로딩: 모델 가중치는 비용이 크므로 CPU로 백업하고, KV 캐시나 중간 텐서는 재계산이 가능하므로 폐기한다. 이 구분을 태그 시스템으로 유연하게 구현했다.

참고 자료

댓글

관련 포스트

vLLM 의 다른글