[vLLM] CUDA Graphs: 커널 런칭 오버헤드 제거
들어가며
LLM 디코드 단계에서는 배치 내 각 요청이 단 하나의 토큰만 생성한다. 이 경우 GPU 연산 자체보다 수십~수백 개의 CUDA 커널을 런칭하는 CPU 오버헤드가 병목이 된다. CUDA Graph는 커널 호출 시퀀스를 한 번 녹화한 뒤 단일 API 호출로 재생하여 이 오버헤드를 제거한다.
소스 경로: vllm/compilation/cuda_graph.py, vllm/v1/worker/gpu/model_runner.py
공식 문서
vLLM 공식 문서: CUDA Graphs
핵심 구조/코드 분석
CUDAGraphWrapper
# vllm/compilation/cuda_graph.py
class CUDAGraphWrapper:
"""Wraps a runnable to add CUDA graph capturing and replaying ability."""
def __init__(self, runnable, vllm_config, runtime_mode, ...):
self.runnable = runnable
self.runtime_mode = runtime_mode # FULL or PIECEWISE
self.graph_pool = current_platform.get_global_graph_pool()
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}
CUDAGraphWrapper는 모델이나 모델의 일부(piecewise 조각)를 감싸서 CUDA 그래프 캡처/재생 기능을 투명하게 추가한다. BatchDescriptor를 키로 사용하여 배치 크기별로 그래프를 캐싱한다.
디스패치 로직
def __call__(self, *args, **kwargs):
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
if (cudagraph_runtime_mode == CUDAGraphMode.NONE
or cudagraph_runtime_mode != self.runtime_mode):
# 프로파일 런, 워밍업, 또는 모드 불일치
return self.runnable(*args, **kwargs)
# CUDA 그래프 캡처 또는 재생
...
호출 시 ForwardContext에서 현재 배치의 CUDA 그래프 모드를 확인한다. 모드가 NONE이거나 래퍼의 모드와 불일치하면 일반 실행으로 폴백한다.
CUDAGraphEntry
@dataclasses.dataclass
class CUDAGraphEntry:
batch_descriptor: BatchDescriptor
cudagraph: torch.cuda.CUDAGraph | None = None
output: Any | None = None
input_addresses: list[int] | None = None # 디버그용
각 배치 크기에 대해 캡처된 그래프, 출력 텐서, 그리고 디버그 모드에서는 입력 주소까지 저장한다. 재생 시 입력 주소가 캡처 시와 동일한지 검증하여 잠재적 버그를 감지한다.
GPUModelRunner에서의 캡처
# vllm/v1/worker/gpu/model_runner.py
@torch.inference_mode()
def capture_model(self) -> int:
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
self.cudagraph_manager.capture(
self.model, self.model_state, self.input_buffers,
self.intermediate_tensors, self.block_tables,
self.attn_groups, self.kv_cache_config,
has_lora=self.lora_config is not None,
use_aux_hidden_state_outputs=self.use_aux_hidden_state_outputs,
)
if self.speculator is not None:
self.speculator.capture_model()
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))
capture_model()은 워밍업 단계에서 호출되어 다양한 배치 크기에 대해 CUDA 그래프를 미리 녹화한다. Speculator(투기적 디코딩 모델)도 별도로 캡처한다. 캡처에 사용된 GPU 메모리를 측정하여 로깅한다.
FULL vs PIECEWISE 모드
vLLM은 두 가지 CUDA 그래프 모드를 지원한다:
- FULL: 전체 모델 포워드 패스를 하나의 CUDA 그래프로 캡처. 오버헤드가 가장 적지만, 그래프 안에서 Python 코드를 실행할 수 없다.
- PIECEWISE: 모델을 여러 조각으로 분할하고 각각을 별도 CUDA 그래프로 캡처. 조각 사이에 Python 코드(KV 커넥터 동기화 등)를 실행할 수 있다.
# model_runner.py
self.cudagraph_manager = ModelCudaGraphManager(
self.vllm_config, self.device,
self.compilation_config.cudagraph_mode, # FULL, PIECEWISE, NONE
decode_query_len=self.decode_query_len,
)
왜 이 설계인가
-
배치 크기별 캐싱: LLM 서빙에서 배치 크기는 동적으로 변한다. 모든 가능한 크기에 대해 그래프를 캡처하면 메모리가 부족하므로, 패딩을 통해 제한된 수의 크기만 캡처한다.
-
WeakRef 출력:
CUDAGraphOptions.weak_ref_output를 통해 출력 텐서를 약한 참조로 관리한다. 이는 CUDA 그래프의 정적 출력 버퍼가 GC에 의해 회수되는 것을 방지하면서도 메모리 누수를 막는다. -
글로벌 그래프 풀:
current_platform.get_global_graph_pool()로 모든 CUDA 그래프가 동일한 메모리 풀을 공유한다. 이는 NCCL 통신과의 호환성을 보장하고 메모리 단편화를 줄인다. -
디버그 모드:
VLLM_LOGGING_LEVEL=DEBUG에서는 매 재생마다 입력 텐서의 메모리 주소를 검증한다. 주소가 변경되면 CUDA 그래프 재생이 올바르지 않으므로, 이를 조기에 감지한다.
참고
관련 포스트
vLLM 의 다른글
- 이전글 [vLLM] torch.compile 통합: PyTorch 컴파일러
- 현재글 : [vLLM] CUDA Graphs: 커널 런칭 오버헤드 제거
- 다음글 [vLLM] Medusa: 다중 예측 헤드 투기적 디코딩
댓글