[SGLang] EAGLE CUDA Graph: 드래프트 모델 가속
들어가며
EAGLE의 드래프트 모델은 매 step마다 작은 모델을 반복 forward하므로, 커널 launch 오버헤드가 전체 드래프팅 시간의 상당 부분을 차지할 수 있다. CUDA Graph는 커널 시퀀스를 한 번 캡처하고 이후에는 재생만으로 실행하여 launch 오버헤드를 거의 제거한다. SGLang의 EAGLEDraftCudaGraphRunner는 드래프트 모델의 다단계 forward를 배치 크기별로 캡처하여 재사용한다.
구조도
┌──────────────────────────────────────────────────────┐
│ EAGLEDraftCudaGraphRunner │
│ │
│ Capture Phase (초기화 시 1회): │
│ ┌─────────────────────────────────────────────────┐ │
│ │ for bs in capture_bs: │ │
│ │ graph = CUDAGraph() │ │
│ │ with graph.capture(): │ │
│ │ eagle_worker.draft_forward(forward_batch) │ │
│ │ graphs[bs] = graph │ │
│ │ output_buffers[bs] = (parent_list, │ │
│ │ top_scores_index, │ │
│ │ draft_tokens) │ │
│ └─────────────────────────────────────────────────┘ │
│ │
│ Replay Phase (매 inference): │
│ ┌─────────────────────────────────────────────────┐ │
│ │ 1. Copy inputs → buffers │ │
│ │ 2. Pad to captured bs │ │
│ │ 3. graphs[bs].replay() │ │
│ │ 4. Trim output to raw_bs │ │
│ └─────────────────────────────────────────────────┘ │
└──────────────────────────────────────────────────────┘
핵심 코드 분석
1. 입력 버퍼 정의
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py에서 CUDA graph의 입력 텐서를 사전 할당한다.
@dataclass
class EagleDraftInputBuffers(ForwardInputBuffers):
input_ids: torch.Tensor
req_pool_indices: torch.Tensor
out_cache_loc: torch.Tensor
positions: torch.Tensor
mrope_positions: torch.Tensor
seq_lens: torch.Tensor
seq_lens_cpu: torch.Tensor
extend_seq_lens: torch.Tensor
topk_p: torch.Tensor
topk_index: torch.Tensor
hidden_states: torch.Tensor
global_num_tokens_gpu: Optional[torch.Tensor]
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor]
모든 입력이 최대 배치 크기로 미리 할당되어, 런타임에 메모리 할당이 발생하지 않는다.
2. Runner 초기화
class EAGLEDraftCudaGraphRunner:
def __init__(self, eagle_worker):
self.model_runner = model_runner = eagle_worker.model_runner
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
self.topk = model_runner.server_args.speculative_eagle_topk
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.num_tokens_per_bs = self.topk
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
# Attention backend CUDA graph 상태 초기화
self.model_runner.draft_attn_backend.init_cuda_graph_state(
self.max_bs, self.max_num_token
)
num_tokens_per_bs = topk인 이유는 각 배치 요소가 top-k개의 토큰을 동시에 처리하기 때문이다. capture_bs는 캡처할 배치 크기 목록으로, get_batch_sizes_to_capture가 시스템 리소스에 맞게 결정한다.
3. 최대 크기 버퍼 할당
with torch.device(model_runner.device):
input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
out_cache_loc = torch.zeros(
(self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
)
hidden_states = torch.zeros(
(self.max_bs, self.model_runner.model_config.hidden_size),
dtype=self.model_runner.dtype,
)
topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
out_cache_loc의 크기가 max_num_token * speculative_num_steps인 이유는 각 step의 모든 top-k 토큰에 대해 캐시 위치가 필요하기 때문이다.
4. 캡처: capture_one_batch_size
각 배치 크기에 대해 CUDA graph를 캡처한다.
def capture_one_batch_size(self, num_seqs, forward, stream_idx=0):
graph = self._create_graph()
spec_info = EagleDraftInput(
topk_p=topk_p, topk_index=topk_index,
hidden_states=hidden_states,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
forward_batch = ForwardBatch(
forward_mode=ForwardMode.DECODE,
batch_size=num_seqs,
spec_info=spec_info, ...
)
# Attention backend 메타데이터 캡처
self.model_runner.draft_attn_backend.init_forward_metadata_capture_cuda_graph(
forward_batch
)
def run_once():
output_cache_loc_backup = forward_batch.out_cache_loc
hidden_states_backup = forward_batch.spec_info.hidden_states
ret = self.eagle_worker.draft_forward(forward_batch)
forward_batch.out_cache_loc = output_cache_loc_backup
forward_batch.spec_info.hidden_states = hidden_states_backup
return ret
self._capture_init(run_once) # 2회 워밍업
out = self._capture_graph(graph, get_global_graph_memory_pool(), stream, run_once)
set_global_graph_memory_pool(graph.pool())
return graph, out
run_once 내에서 out_cache_loc과 hidden_states를 백업/복원하는 이유는 draft_forward가 이 텐서를 in-place로 수정하기 때문이다. CUDA graph는 같은 메모리 주소를 재사용해야 하므로 원본을 보존한다.
5. 재생: replay
실제 inference 시 캡처된 graph를 재생한다.
def replay(self, forward_batch: ForwardBatch):
buffers = self.buffers
raw_bs = forward_batch.batch_size
raw_num_token = raw_bs * self.num_tokens_per_bs
# 패딩이 필요한 경우
if self.require_mlp_tp_gather:
max_num_tokens = max(forward_batch.global_num_tokens_cpu)
max_batch_size = max_num_tokens // self.num_tokens_per_bs
index = bisect.bisect_left(self.capture_bs, max_batch_size)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
buffers.seq_lens.fill_(self.seq_len_fill_value)
buffers.out_cache_loc.zero_()
buffers.positions.zero_()
# 실제 데이터 복사
buffers.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
buffers.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
buffers.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
buffers.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
# Attention 메타데이터 업데이트
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
forward_batch, bs
)
# Graph 재생
self.graphs[bs].replay()
out = self.output_buffers[bs]
# 패딩 제거
if bs != raw_bs:
out = self._postprocess_output_to_raw_bs(out, raw_bs)
return out
bisect_left로 현재 배치 크기 이상인 최소 캡처 크기를 찾아 해당 graph를 사용한다. 패딩 영역은 seq_len_fill_value로 채워 attention이 무시하도록 한다.
6. 실행 가능 여부 판단
def can_run(self, forward_batch: ForwardBatch):
if self.require_mlp_tp_gather:
cuda_graph_bs = max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
else:
cuda_graph_bs = forward_batch.batch_size
is_bs_supported = (
cuda_graph_bs in self.graphs
if self.disable_padding
else cuda_graph_bs <= self.max_bs
)
return is_bs_supported
disable_padding 모드에서는 정확히 캡처된 배치 크기만 허용하고, 기본 모드에서는 최대 배치 크기 이하면 패딩으로 처리한다.
CUDA Graph의 제약과 해결
| 제약 | SGLang의 해결 |
|---|---|
| 텐서 크기 고정 | 최대 크기로 사전 할당 + 패딩 |
| 동적 제어 흐름 불가 | draft_forward 내부를 deterministic하게 유지 |
| 메모리 주소 고정 | in-place 수정 텐서를 백업/복원 |
| 그래프 캡처 시간 | 서버 시작 시 1회만 수행 |
설계 근거
배치 크기별 캡처: 런타임 배치 크기가 가변적이므로, 여러 크기를 미리 캡처하고 bisect로 최적 크기를 빠르게 찾는다.
Global Graph Memory Pool: set_global_graph_memory_pool으로 모든 graph가 같은 memory pool을 공유하여 CUDA graph 간 메모리 중복 할당을 방지한다.
Attention Backend 연동: init_forward_metadata_capture_cuda_graph와 init_forward_metadata_replay_cuda_graph로 attention metadata도 graph에 포함시켜, metadata 계산 오버헤드까지 제거한다.
관련 포스트
참고
- SGLang EAGLE Draft CUDA Graph Runner 소스
- NVIDIA CUDA Programming Guide: CUDA Graphs
관련 포스트
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
- [sglang] sglang diffusion 모델 성능 향상: Cache-DiT와 torch.compile의 최적화된 적용 순서
- [sglang] NixlKVManager 성능 향상: 비동기 및 멀티스레드 KV 전송 도입
SGLang 의 다른글
- 이전글 [SGLang] DFlash: Flash 기반 고속 드래프팅
- 현재글 : [SGLang] EAGLE CUDA Graph: 드래프트 모델 가속
- 다음글 [SGLang] Tree Search & Verification: 트리 기반 추측과 검증
댓글