본문으로 건너뛰기

[vLLM] GPU Worker & InputBatch

들어가며

vLLM에서 실제 GPU 연산을 수행하는 것은 Worker이다. Worker는 모델 로딩, 디바이스 초기화, 메모리 프로파일링, 모델 실행의 전체 생명주기를 관리한다. 그리고 InputBatch는 스케줄러의 출력을 GPU가 즉시 처리할 수 있는 텐서 배치로 변환한다.

코드: vllm/v1/worker/gpu_worker.py, vllm/v1/worker/gpu/input_batch.py

핵심 구조/코드 분석

Worker 초기화

class Worker(WorkerBase):
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        super().__init__(...)

        precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
        torch.set_float32_matmul_precision(precision)

        self.elastic_ep_executor = ElasticEPScalingExecutor(self)
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

Worker는 local_rank로 자신이 사용할 GPU를 결정하고, 분산 환경 초기화를 담당한다. VLLM_FLOAT32_MATMUL_PRECISION 환경변수로 matmul 정밀도를 제어할 수 있다.

디바이스 초기화

def init_device(self):
    # DP adjusted local rank
    dp_local_rank = self.parallel_config.data_parallel_rank_local
    tp_pp_world_size = (
        self.parallel_config.pipeline_parallel_size
        * self.parallel_config.tensor_parallel_size
    )
    self.local_rank += dp_local_rank * tp_pp_world_size

    self.device = torch.device(f"cuda:{self.local_rank}")
    torch.accelerator.set_device_index(self.device)

    init_worker_distributed_environment(
        self.vllm_config, self.rank, self.distributed_init_method,
        self.local_rank, current_platform.dist_backend,
    )
    set_random_seed(self.model_config.seed)

    self.init_snapshot = MemorySnapshot(device=self.device)
    self.requested_memory = request_memory(init_snapshot, self.cache_config)

Data Parallel과 Tensor Parallel이 결합될 때 local_rank를 조정하는 로직이 핵심이다. DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK 공식으로 GPU 인덱스를 계산한다. NCCL 초기화 후 메모리 스냅샷을 찍어 사용 가능한 GPU 메모리를 측정한다.

메모리 프로파일링

@torch.inference_mode()
def determine_available_memory(self) -> int:
    if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
        self.model_runner.profile_run()
        return kv_cache_memory_bytes

    # Profile peak memory usage
    ...

KV 캐시에 할당할 메모리를 결정하는 과정이다. 사용자가 명시적으로 지정하면 그 값을 사용하고, 그렇지 않으면 프로파일 실행을 통해 모델의 최대 메모리 사용량을 측정한 뒤 남은 메모리를 KV 캐시에 할당한다.

Sleep/Wake Up: GPU 메모리 관리

def sleep(self, level: int = 1) -> None:
    free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

    if level == 2:
        model = self.model_runner.model
        self._sleep_saved_buffers = {
            name: buffer.cpu().clone()
            for name, buffer in model.named_buffers()
        }

    allocator = CuMemAllocator.get_instance()
    allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())

def wake_up(self, tags: list[str] | None = None) -> None:
    allocator = CuMemAllocator.get_instance()
    allocator.wake_up(tags)

    if len(self._sleep_saved_buffers):
        model = self.model_runner.model
        for name, buffer in model.named_buffers():
            if name in self._sleep_saved_buffers:
                buffer.data.copy_(self._sleep_saved_buffers[name].data)

Level 1 sleep은 가중치만 CPU로 오프로드하고, Level 2는 버퍼까지 포함한다. CuMemAllocator를 통해 CUDA 메모리를 태그 단위로 관리한다. Wake up 시 버퍼를 복원하고, FP8 KV 캐시의 스케일링 팩터도 재초기화한다.

InputBatch: GPU 텐서 배치 관리

@dataclass
class InputBatch:
    req_ids: list[str]
    num_reqs: int
    num_reqs_after_padding: int

    idx_mapping: torch.Tensor        # batch_idx -> req_state_idx
    num_scheduled_tokens: np.ndarray  # batch_idx -> num_scheduled_tokens
    num_tokens: int

    query_start_loc: torch.Tensor    # [num_reqs + 1]
    seq_lens: torch.Tensor           # [num_reqs]
    input_ids: torch.Tensor          # [num_tokens_after_padding]
    positions: torch.Tensor          # [num_tokens_after_padding]
    logits_indices: torch.Tensor     # [total_num_logits]

InputBatch는 스케줄러가 결정한 배치를 GPU 텐서로 표현한다. 핵심 필드들:

  • query_start_loc: 각 요청의 시작 위치 (prefix sum). Flash Attention에서 가변 길이 시퀀스를 처리하는 데 사용된다.
  • seq_lens: 각 시퀀스의 총 길이 (프리필 + 디코딩)
  • logits_indices: 로짓을 추출할 위치. 요청마다 마지막 토큰 위치를 가리킨다.

InputBuffers: 재사용 가능한 GPU 버퍼

class InputBuffers:
    def __init__(self, max_num_reqs, max_num_tokens, device):
        self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
        self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
        self.query_start_loc = torch.zeros(
            max_num_reqs + 1, dtype=torch.int32, device=device
        )
        self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)

매 이터레이션마다 새 텐서를 할당하면 CUDA 메모리 할당 오버헤드가 크다. InputBuffers는 최대 크기로 미리 할당된 버퍼를 제공하고, 슬라이싱으로 필요한 부분만 사용한다.

AsyncIntermediateTensors: 비동기 PP 통신

class AsyncIntermediateTensors(IntermediateTensors):
    def __init__(self, tensors, comm_handles=None, comm_postprocess=None):
        super().__init__(tensors)
        self._comm_handles = comm_handles
        self._comm_waited = False

    def __getattribute__(self, name):
        if name == "tensors" and not object.__getattribute__(self, "_comm_waited"):
            object.__getattribute__(self, "wait_for_comm")()
        return object.__getattribute__(self, name)

파이프라인 병렬화에서 스테이지 간 텐서 전송을 지연 대기한다. .tensors에 접근할 때 자동으로 통신 완료를 기다리므로, 연산과 통신을 겹칠 수 있다.

왜 이 설계인가

  1. 사전 할당 버퍼: InputBuffers로 GPU 메모리 할당/해제 오버헤드를 제거한다. 매 이터레이션의 지연 시간이 마이크로초 단위로 줄어든다.

  2. 태그 기반 메모리 관리: Sleep/Wake Up에서 가중치와 KV 캐시를 독립적으로 관리한다. 모델 교체 시 가중치만 교체하고 KV 캐시는 유지하는 등의 유연한 전략이 가능하다.

  3. 지연 통신 대기: AsyncIntermediateTensors__getattribute__ 오버라이드로, 텐서 사용 시점까지 통신 대기를 미룬다. PP에서 연산-통신 오버랩을 자연스럽게 구현한다.

  4. DP rank 기반 GPU 매핑: Data Parallel과 Tensor/Pipeline Parallel이 결합된 복잡한 환경에서도 올바른 GPU 인덱스를 계산한다.

정리

GPU Worker는 vLLM의 실행 계층에서 가장 하드웨어에 가까운 컴포넌트이다. 메모리 프로파일링, 사전 할당 버퍼, 비동기 통신, Sleep 메커니즘 등 GPU 자원을 극한까지 활용하기 위한 최적화가 집약되어 있다.

댓글

관련 포스트

vLLM 의 다른글