본문으로 건너뛰기

[SGLang] EAGLE CUDA Graph: 드래프트 모델 가속

들어가며

EAGLE의 드래프트 모델은 매 step마다 작은 모델을 반복 forward하므로, 커널 launch 오버헤드가 전체 드래프팅 시간의 상당 부분을 차지할 수 있다. CUDA Graph는 커널 시퀀스를 한 번 캡처하고 이후에는 재생만으로 실행하여 launch 오버헤드를 거의 제거한다. SGLang의 EAGLEDraftCudaGraphRunner는 드래프트 모델의 다단계 forward를 배치 크기별로 캡처하여 재사용한다.

구조도

┌──────────────────────────────────────────────────────┐
│           EAGLEDraftCudaGraphRunner                   │
│                                                      │
│  Capture Phase (초기화 시 1회):                       │
│  ┌─────────────────────────────────────────────────┐ │
│  │ for bs in capture_bs:                            │ │
│  │   graph = CUDAGraph()                           │ │
│  │   with graph.capture():                         │ │
│  │     eagle_worker.draft_forward(forward_batch)   │ │
│  │   graphs[bs] = graph                            │ │
│  │   output_buffers[bs] = (parent_list,            │ │
│  │                         top_scores_index,       │ │
│  │                         draft_tokens)            │ │
│  └─────────────────────────────────────────────────┘ │
│                                                      │
│  Replay Phase (매 inference):                        │
│  ┌─────────────────────────────────────────────────┐ │
│  │ 1. Copy inputs → buffers                        │ │
│  │ 2. Pad to captured bs                           │ │
│  │ 3. graphs[bs].replay()                          │ │
│  │ 4. Trim output to raw_bs                        │ │
│  └─────────────────────────────────────────────────┘ │
└──────────────────────────────────────────────────────┘

핵심 코드 분석

1. 입력 버퍼 정의

python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py에서 CUDA graph의 입력 텐서를 사전 할당한다.

@dataclass
class EagleDraftInputBuffers(ForwardInputBuffers):
    input_ids: torch.Tensor
    req_pool_indices: torch.Tensor
    out_cache_loc: torch.Tensor
    positions: torch.Tensor
    mrope_positions: torch.Tensor
    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor
    extend_seq_lens: torch.Tensor
    topk_p: torch.Tensor
    topk_index: torch.Tensor
    hidden_states: torch.Tensor
    global_num_tokens_gpu: Optional[torch.Tensor]
    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor]

모든 입력이 최대 배치 크기로 미리 할당되어, 런타임에 메모리 할당이 발생하지 않는다.

2. Runner 초기화

class EAGLEDraftCudaGraphRunner:
    def __init__(self, eagle_worker):
        self.model_runner = model_runner = eagle_worker.model_runner
        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
        self.topk = model_runner.server_args.speculative_eagle_topk
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)

        self.num_tokens_per_bs = self.topk
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs

        # Attention backend CUDA graph 상태 초기화
        self.model_runner.draft_attn_backend.init_cuda_graph_state(
            self.max_bs, self.max_num_token
        )

num_tokens_per_bs = topk인 이유는 각 배치 요소가 top-k개의 토큰을 동시에 처리하기 때문이다. capture_bs는 캡처할 배치 크기 목록으로, get_batch_sizes_to_capture가 시스템 리소스에 맞게 결정한다.

3. 최대 크기 버퍼 할당

with torch.device(model_runner.device):
    input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
    out_cache_loc = torch.zeros(
        (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
    )
    hidden_states = torch.zeros(
        (self.max_bs, self.model_runner.model_config.hidden_size),
        dtype=self.model_runner.dtype,
    )
    topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
    topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)

out_cache_loc의 크기가 max_num_token * speculative_num_steps인 이유는 각 step의 모든 top-k 토큰에 대해 캐시 위치가 필요하기 때문이다.

4. 캡처: capture_one_batch_size

각 배치 크기에 대해 CUDA graph를 캡처한다.

def capture_one_batch_size(self, num_seqs, forward, stream_idx=0):
    graph = self._create_graph()
    spec_info = EagleDraftInput(
        topk_p=topk_p, topk_index=topk_index,
        hidden_states=hidden_states,
        capture_hidden_mode=CaptureHiddenMode.LAST,
    )

    forward_batch = ForwardBatch(
        forward_mode=ForwardMode.DECODE,
        batch_size=num_seqs,
        spec_info=spec_info, ...
    )

    # Attention backend 메타데이터 캡처
    self.model_runner.draft_attn_backend.init_forward_metadata_capture_cuda_graph(
        forward_batch
    )

    def run_once():
        output_cache_loc_backup = forward_batch.out_cache_loc
        hidden_states_backup = forward_batch.spec_info.hidden_states
        ret = self.eagle_worker.draft_forward(forward_batch)
        forward_batch.out_cache_loc = output_cache_loc_backup
        forward_batch.spec_info.hidden_states = hidden_states_backup
        return ret

    self._capture_init(run_once)  # 2회 워밍업
    out = self._capture_graph(graph, get_global_graph_memory_pool(), stream, run_once)
    set_global_graph_memory_pool(graph.pool())
    return graph, out

run_once 내에서 out_cache_lochidden_states를 백업/복원하는 이유는 draft_forward가 이 텐서를 in-place로 수정하기 때문이다. CUDA graph는 같은 메모리 주소를 재사용해야 하므로 원본을 보존한다.

5. 재생: replay

실제 inference 시 캡처된 graph를 재생한다.

def replay(self, forward_batch: ForwardBatch):
    buffers = self.buffers
    raw_bs = forward_batch.batch_size
    raw_num_token = raw_bs * self.num_tokens_per_bs

    # 패딩이 필요한 경우
    if self.require_mlp_tp_gather:
        max_num_tokens = max(forward_batch.global_num_tokens_cpu)
        max_batch_size = max_num_tokens // self.num_tokens_per_bs
        index = bisect.bisect_left(self.capture_bs, max_batch_size)
    else:
        index = bisect.bisect_left(self.capture_bs, raw_bs)

    bs = self.capture_bs[index]
    if bs != raw_bs:
        buffers.seq_lens.fill_(self.seq_len_fill_value)
        buffers.out_cache_loc.zero_()
        buffers.positions.zero_()

    # 실제 데이터 복사
    buffers.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
    buffers.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
    buffers.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
    buffers.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)

    # Attention 메타데이터 업데이트
    self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
        forward_batch, bs
    )

    # Graph 재생
    self.graphs[bs].replay()
    out = self.output_buffers[bs]

    # 패딩 제거
    if bs != raw_bs:
        out = self._postprocess_output_to_raw_bs(out, raw_bs)
    return out

bisect_left로 현재 배치 크기 이상인 최소 캡처 크기를 찾아 해당 graph를 사용한다. 패딩 영역은 seq_len_fill_value로 채워 attention이 무시하도록 한다.

6. 실행 가능 여부 판단

def can_run(self, forward_batch: ForwardBatch):
    if self.require_mlp_tp_gather:
        cuda_graph_bs = max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
    else:
        cuda_graph_bs = forward_batch.batch_size

    is_bs_supported = (
        cuda_graph_bs in self.graphs
        if self.disable_padding
        else cuda_graph_bs <= self.max_bs
    )
    return is_bs_supported

disable_padding 모드에서는 정확히 캡처된 배치 크기만 허용하고, 기본 모드에서는 최대 배치 크기 이하면 패딩으로 처리한다.

CUDA Graph의 제약과 해결

제약 SGLang의 해결
텐서 크기 고정 최대 크기로 사전 할당 + 패딩
동적 제어 흐름 불가 draft_forward 내부를 deterministic하게 유지
메모리 주소 고정 in-place 수정 텐서를 백업/복원
그래프 캡처 시간 서버 시작 시 1회만 수행

설계 근거

배치 크기별 캡처: 런타임 배치 크기가 가변적이므로, 여러 크기를 미리 캡처하고 bisect로 최적 크기를 빠르게 찾는다.

Global Graph Memory Pool: set_global_graph_memory_pool으로 모든 graph가 같은 memory pool을 공유하여 CUDA graph 간 메모리 중복 할당을 방지한다.

Attention Backend 연동: init_forward_metadata_capture_cuda_graphinit_forward_metadata_replay_cuda_graph로 attention metadata도 graph에 포함시켜, metadata 계산 오버헤드까지 제거한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글