[SGLang] ViT CUDA Graph: Vision Encoder 가속
들어가며
VLM의 Vision Encoder(ViT)는 이미지 패치를 임베딩으로 변환하는 반복적 연산이다. 동일한 입력 크기가 반복될 때 CUDA Graph로 캡처하면 커널 런치 오버헤드를 제거할 수 있다. SGLang의 ViTCudaGraphRunner는 이 최적화를 Qwen2.5-VL, Qwen3-VL 등에 적용한다.
이 글에서는 python/sglang/srt/multimodal/vit_cuda_graph_runner.py를 분석한다.
CUDA Graph 캡처/재생 구조도
첫 번째 호출 (seq_len=1024):
┌──────────────────────────────────────────┐
│ create_graph(x_3d, cu_seqlens, ...) │
│ 1. workspace 할당 (block_input, block_ws)│
│ 2. cu_seqlens 버퍼 생성 │
│ 3. rotary pos embedding 복사 │
│ 4. torch.cuda.graph(graph) 컨텍스트에서 │
│ ViT blocks + merger 실행 캡처 │
│ 5. graph_key(=1024) → graphs dict에 저장 │
└──────────────────────────────────────────┘
두 번째 호출 (seq_len=1024):
┌──────────────────────────────────────────┐
│ replay(graph_key=1024, x_3d, ...) │
│ 1. rotary workspace 업데이트 │
│ 2. block_input.copy_(x_3d) │
│ 3. graph.replay() ← 커널 재발사 없음! │
│ 4. block_output 반환 │
└──────────────────────────────────────────┘
핵심 코드 분석
초기화: 모델 특성 탐지
ViTCudaGraphRunner는 ViT 모델의 특성을 자동으로 감지한다.
class ViTCudaGraphRunner:
def __init__(self, vit: nn.Module):
self.vit = vit
# graph_key -> buffers/graphs
self.block_input: Dict[Hashable, torch.Tensor] = {}
self.block_ws: Dict[Hashable, torch.Tensor] = {}
self.block_graphs: Dict[Hashable, torch.cuda.CUDAGraph] = {}
self.block_output: Dict[Hashable, torch.Tensor] = {}
# Qwen2.5-VL: windowed attention
self._fullatt_block_indexes = set(
getattr(vit, "fullatt_block_indexes", ())
)
# Qwen3-VL: deepstack
self._deepstack_visual_indexes = list(
getattr(vit, "deepstack_visual_indexes", []) or []
)
self._deepstack_merger_list = getattr(vit, "deepstack_merger_list", None)
세 가지 모델 변형을 지원한다.
- 기본 ViT: blocks + merger
- Qwen2.5-VL: windowed attention + full attention 블록 혼합
- Qwen3-VL: deepstack merger 추가
Graph Key: 시퀀스 길이 기반
def _get_graph_key(self, x_3d: torch.Tensor) -> int:
return x_3d.shape[0] # 시퀀스 길이
CUDA Graph는 입력 형상이 동일해야 재생할 수 있다. 시퀀스 길이를 key로 사용하여 같은 크기의 이미지가 재방문할 때 그래프를 재활용한다.
Workspace 할당
그래프 캡처 전에 입력, 어텐션 워크스페이스, 출력 버퍼를 사전 할당한다.
def create_graph(self, x_3d, cu_seqlens, cu_window_seqlens, position_embeddings, ...):
graph_key = self._get_graph_key(x_3d)
if graph_key in self.block_graphs:
return graph_key
attn_module = vit.blocks[0].attn
num_heads = attn_module.num_attention_heads_per_partition
attn_head_dim = attn_module.head_size
self.block_output[graph_key] = torch.empty_like(x_3d).contiguous()
self.block_input[graph_key] = torch.empty_like(x_3d).contiguous()
self.block_ws[graph_key] = torch.empty(
graph_key, num_heads, attn_head_dim,
device=self.device, dtype=self.dtype,
)
어텐션 워크스페이스(block_ws)는 (seq_len, num_heads, head_dim) 형상으로, 각 블록의 중간 결과를 저장한다.
그래프 캡처: 블록 순회
def _create_graph(self, graph_key, position_embeddings=None, ...):
graph = torch.cuda.CUDAGraph()
with capture_ctx, torch.cuda.graph(graph):
y = None
deepstack_outs = []
for layer_num, blk in enumerate(vit.blocks):
# Qwen2.5-VL: 블록별 attention 유형 결정
if self._fullatt_block_indexes:
if layer_num in vit.fullatt_block_indexes:
cu_seqlens_now = cu_full
else:
cu_seqlens_now = cu_window
# ...
if layer_num == 0:
y = blk(self.block_input[graph_key], cu_seqlens=cu_seq_len_ws, ...)
else:
y = blk(y, cu_seqlens=cu_seq_len_ws, ...)
# Qwen3-VL: deepstack 처리
if layer_num in self._deepstack_visual_indexes:
deepstack_out = self._deepstack_merger_list[idx](y)
deepstack_outs.append(deepstack_out)
main_out = vit.merger(y)
if deepstack_outs:
self.block_output[graph_key] = torch.cat(
[main_out] + deepstack_outs, dim=1
)
Qwen2.5-VL은 특정 블록에서 full attention, 나머지에서 windowed attention을 사용한다. 그래프 캡처 시 이 분기도 함께 캡처된다.
Rotary Position Embedding 관리
회전 위치 임베딩은 이미지마다 달라지므로 별도 워크스페이스에서 관리한다.
def _ensure_sin_cos_ws(self, seq_len, head_dim):
if self.sin_cos_ws is None:
max_shape = self.max_context_len or seq_len
cos_ws = torch.empty(max_shape, head_dim, ...)
sin_ws = torch.empty(max_shape, head_dim, ...)
self.sin_cos_ws = (cos_ws, sin_ws)
elif self.sin_cos_ws[0].size(0) < seq_len:
max_shape = max(self.sin_cos_ws[0].size(0) * 2, seq_len)
# 2배 증가로 재할당 빈도 감소
2배 증가 전략으로 재할당 횟수를 O(log N)으로 줄인다.
재생: 입력 복사 + replay
def replay(self, graph_key, x_3d, position_embeddings=None, ...):
if position_embeddings is not None:
used_cos_ws = self.sin_cos_ws[0][:graph_key, :]
used_sin_ws = self.sin_cos_ws[1][:graph_key, :]
used_cos_ws.copy_(position_embeddings[0])
used_sin_ws.copy_(position_embeddings[1])
self.block_input[graph_key].copy_(x_3d)
self.block_graphs[graph_key].replay()
out = self.block_output[graph_key]
if output_indices is not None:
out = out.index_select(0, output_indices)
return out
replay()는 입력 데이터만 복사하고 캡처된 커널 시퀀스를 재실행한다. output_indices는 Qwen2.5-VL의 windowed attention permutation 역변환에 사용된다.
진입점: run()
def run(self, x, cu_seqlens, cu_window_seqlens, position_embeddings, ...):
x_3d = x.unsqueeze(1) # [seq_len, hidden] -> [S, 1, H]
graph_key = self._get_graph_key(x_3d)
if graph_key not in self.block_graphs:
self.create_graph(x_3d=x_3d, ...)
return self.replay(graph_key=graph_key, x_3d=x_3d, ...)
첫 호출 시 캡처, 이후 동일 크기면 재생한다.
성능 영향
CUDA Graph 없음:
이미지 인코딩 = N블록 × (커널 런치 + 계산)
커널 런치 오버헤드 ≈ 5~15μs × 수십 개 커널 × N블록
CUDA Graph 적용:
첫 호출: 캡처 비용 (1회)
이후: 단일 replay() = 전체 블록 체인 한 번에 실행
커널 런치 오버헤드 ≈ 0
ViT 블록 수가 많고(24~48블록) 이미지 크기가 반복적일수록 효과가 크다.
설계 근거
| 설계 선택 | 이유 |
|---|---|
| seq_len 기반 graph_key | 동일 해상도 이미지가 반복될 때 그래프 재활용 |
| 별도 sin/cos workspace | 위치 임베딩은 이미지마다 달라 매번 복사 필요 |
| 2배 증가 워크스페이스 | 재할당 빈도 O(log N), amortized 비용 절감 |
| contiguous() 강제 | CUDA Graph는 연속 메모리 레이아웃을 요구 |
| TP ca_comm.capture() | 텐서 병렬의 커뮤니케이션도 그래프에 포함 |
관련 포스트
- Multimodal 처리 파이프라인 개요 - 전체 멀티모달 파이프라인 구조
- Vision-Language 모델: CLIP, InternVL, LLaVA - ViT를 사용하는 VLM 프로세서
- Efficient Vision Sampling: 이미지 토큰 압축 - ViT 출력 후 토큰 압축
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Audio 모델: Whisper, Qwen3-ASR, GLM-ASR 프로세서
- 현재글 : [SGLang] ViT CUDA Graph: Vision Encoder 가속
- 다음글 [SGLang] Efficient Vision Sampling: 이미지 토큰 압축
댓글