본문으로 건너뛰기

[SGLang] CUDA Graphs: 커널 런칭 오버헤드 제거

들어가며

LLM 추론의 Decode 단계에서는 배치 내 각 시퀀스가 토큰 1개씩만 생성한다. 이때 실제 GPU 연산 시간보다 커널 런칭(launch) 오버헤드가 더 클 수 있다. CUDA Graph는 커널 실행 순서를 한 번 녹화(capture)한 뒤 재생(replay)하여 이 오버헤드를 제거한다.

이 글에서는 python/sglang/srt/model_executor/cuda_graph_runner.py를 중심으로 CudaGraphRunner의 설계를 분석한다.

Before/After: 커널 런칭 오버헤드

CUDA Graph 적용 전후의 차이를 그림으로 비교한다.

Without CUDA Graph:
┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
│Launch│ │Launch│ │Launch│ │Launch│ │Launch│
│ K1  │ │ K2  │ │ K3  │ │ K4  │ │ K5  │
├─────┤ ├─────┤ ├─────┤ ├─────┤ ├─────┤
│ Run │ │ Run │ │ Run │ │ Run │ │ Run │
│ K1  │ │ K2  │ │ K3  │ │ K4  │ │ K5  │
└─────┘ └─────┘ └─────┘ └─────┘ └─────┘
  ↑ CPU Launch 오버헤드가 누적

With CUDA Graph:
┌───────────────────────────────────────┐
│  Graph Replay (단일 커널 런칭)          │
│  ┌────┬────┬────┬────┬────┐          │
│  │ K1 │ K2 │ K3 │ K4 │ K5 │          │
│  └────┴────┴────┴────┴────┘          │
└───────────────────────────────────────┘
  ↑ CPU는 replay 1번만 호출

Decode 단계에서 수백 개의 소형 커널이 연속 실행되므로, 이 오버헤드 제거의 효과는 상당하다.

CudaGraphRunner 초기화

CudaGraphRunner는 ModelRunner 초기화 시 생성된다. 다양한 배치 크기에 대해 CUDA Graph를 사전 캡처한다.

class CudaGraphRunner:
    def __init__(self, model_runner: ModelRunner):
        self.model_runner = model_runner
        self.graphs = {}          # bs → CUDAGraph
        self.output_buffers = {}  # bs → output tensor

        # 캡처할 배치 크기 결정
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(
            model_runner, self.num_tokens_per_bs
        )

get_batch_sizes_to_capture()는 서버 설정의 cuda_graph_bs 목록을 기반으로 유효한 배치 크기를 필터링한다.

def get_batch_sizes_to_capture(model_runner, num_tokens_per_bs=1):
    capture_bs = server_args.cuda_graph_bs
    num_max_requests = model_runner.req_to_token_pool.size

    # DP Attention, Two-Batch Overlap 등의 배수 조건
    mul_base = 1
    if server_args.enable_two_batch_overlap:
        mul_base *= 2
    if require_gathered_buffer(server_args):
        mul_base *= get_attention_tp_size()

    # 배수 조건과 최대 배치 크기로 필터링
    capture_bs = [bs for bs in capture_bs if bs * num_tokens_per_bs % mul_base == 0]
    capture_bs = [bs for bs in capture_bs if bs <= num_max_requests]
    return list(sorted(set(capture_bs))), compile_bs

DecodeInputBuffers: 입력 버퍼

CUDA Graph는 고정 메모리 주소를 요구한다. DecodeInputBuffers가 모든 배치 크기에 대해 재사용되는 고정 버퍼를 제공한다.

@dataclass
class DecodeInputBuffers(ForwardInputBuffers):
    input_ids: torch.Tensor
    req_pool_indices: torch.Tensor
    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor
    out_cache_loc: torch.Tensor
    positions: torch.Tensor
    num_token_non_padded: torch.Tensor
    next_token_logits_buffer: torch.Tensor
    ...

create() 클래스메서드에서 최대 배치 크기로 한 번 할당한다.

@classmethod
def create(cls, *, device, max_bs, max_num_token, hidden_size, vocab_size, ...):
    with torch.device(device):
        input_ids = torch.zeros((max_num_token,), dtype=torch.int64)
        seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)
        positions = torch.zeros((max_num_token,), dtype=torch.int64)
        next_token_logits_buffer = torch.zeros(
            (max_num_token, vocab_size), dtype=torch.float,
        )

그래프 캡처: capture()

capture() 메서드는 모든 배치 크기에 대해 CUDA Graph를 캡처한다. 큰 배치 크기부터 역순으로 캡처하여 메모리 재사용을 최적화한다.

def capture(self) -> None:
    def _capture_one_stream(stream_idx=None):
        # 역순으로 캡처 (큰 배치 → 작은 배치)
        capture_range = reversed(self.capture_bs)
        for bs in capture_range:
            with patch_model(
                self.model_runner.model,
                bs in self.compile_bs,  # torch.compile 적용 여부
                num_tokens=bs * self.num_tokens_per_bs,
                tp_group=self.model_runner.tp_group,
            ) as forward:
                graph, output_buffers = self.capture_one_batch_size(bs, forward)
                self.graphs[bs] = graph
                self.output_buffers[bs] = output_buffers

    with freeze_gc(self.model_runner.server_args.enable_cudagraph_gc):
        with graph_capture() as graph_capture_context:
            self.stream = graph_capture_context.stream
            _capture_one_stream()

freeze_gc()는 캡처 중 GC(가비지 컬렉션)를 동결하여 안정성을 보장한다.

@contextmanager
def freeze_gc(enable_cudagraph_gc):
    gc.collect()
    should_freeze = not enable_cudagraph_gc
    if should_freeze:
        gc.freeze()
    try:
        yield
    finally:
        if should_freeze:
            gc.unfreeze()

단일 배치 크기 캡처

capture_one_batch_size()에서 실제 CUDA Graph 녹화가 수행된다.

def capture_one_batch_size(self, bs, forward, stream_idx=None):
    graph = self._create_device_graph()
    num_tokens = bs * self.num_tokens_per_bs

    # 입력 버퍼 슬라이싱 (고정 주소)
    input_ids = buffers.input_ids[:num_tokens]
    seq_lens = buffers.seq_lens[:bs]
    out_cache_loc = buffers.out_cache_loc[:num_tokens]
    positions = buffers.positions[:num_tokens]
    ...

그래프 캡처는 _capture_graph()에서 수행된다.

def _capture_graph(self, graph, pool, stream, run_once_fn):
    with self.device_module.graph(cuda_graph=graph, pool=pool, stream=stream):
        out = run_once_fn()
    return out

글로벌 메모리 풀

모든 CudaGraphRunner 인스턴스는 단일 메모리 풀을 공유한다. 이는 Draft Worker와 Target Worker가 동시에 존재하는 Speculative Decoding 환경에서 중요하다.

# Reuse this memory pool across all cuda graph runners.
global_graph_memory_pool = None

def get_global_graph_memory_pool():
    return global_graph_memory_pool

def set_global_graph_memory_pool(val):
    global global_graph_memory_pool
    global_graph_memory_pool = val

그래프 재생: can_run과 replay

실행 시 can_run()으로 그래프 실행 가능 여부를 확인하고, replay()로 재생한다.

def can_run(self, forward_batch: ForwardBatch):
    # 토큰 임베딩 오버라이드가 있으면 그래프 사용 불가
    if forward_batch.replace_embeds is not None:
        return False

    if self.require_mlp_tp_gather:
        cuda_graph_bs = max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
    else:
        cuda_graph_bs = forward_batch.batch_size

    is_bs_supported = (
        graph_key in self.graphs
        if self.disable_padding
        else cuda_graph_bs <= self.max_bs
    )

패딩이 활성화된 경우(기본값), 요청된 배치 크기보다 크거나 같은 가장 가까운 캡처된 크기를 사용한다.

torch.compile 통합

특정 배치 크기에 대해 torch.compile을 적용한 후 CUDA Graph를 캡처할 수 있다.

with patch_model(
    self.model_runner.model,
    bs in self.compile_bs,  # compile 대상 여부
    num_tokens=bs * self.num_tokens_per_bs,
    tp_group=self.model_runner.tp_group,
) as forward:
    graph, output_buffers = self.capture_one_batch_size(bs, forward)

set_torch_compile_config()에서 Inductor 최적화를 설정한다.

def set_torch_compile_config():
    torch._inductor.config.coordinate_descent_tuning = True
    torch._inductor.config.triton.unique_kernel_names = True
    torch._inductor.config.fx_graph_cache = True

설계 근거: 왜 배치 크기별로 캡처하는가

CUDA Graph는 실행 흐름의 고정 스냅샷이므로, 배치 크기가 달라지면 별도 그래프가 필요하다. SGLang은 자주 사용되는 배치 크기만 사전 캡처하고, 패딩으로 중간 크기를 처리한다.

전략 장점 단점
패딩 활성화 (기본) 적은 그래프 수, 빠른 초기화 패딩으로 인한 약간의 낭비
패딩 비활성화 정확한 연산 많은 그래프 수, 느린 초기화

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글