본문으로 건너뛰기

[SGLang] ViT CUDA Graph: Vision Encoder 가속

들어가며

VLM의 Vision Encoder(ViT)는 이미지 패치를 임베딩으로 변환하는 반복적 연산이다. 동일한 입력 크기가 반복될 때 CUDA Graph로 캡처하면 커널 런치 오버헤드를 제거할 수 있다. SGLang의 ViTCudaGraphRunner는 이 최적화를 Qwen2.5-VL, Qwen3-VL 등에 적용한다.

이 글에서는 python/sglang/srt/multimodal/vit_cuda_graph_runner.py를 분석한다.

CUDA Graph 캡처/재생 구조도

첫 번째 호출 (seq_len=1024):
┌──────────────────────────────────────────┐
│ create_graph(x_3d, cu_seqlens, ...)      │
│  1. workspace 할당 (block_input, block_ws)│
│  2. cu_seqlens 버퍼 생성                  │
│  3. rotary pos embedding 복사             │
│  4. torch.cuda.graph(graph) 컨텍스트에서  │
│     ViT blocks + merger 실행 캡처         │
│  5. graph_key(=1024) → graphs dict에 저장 │
└──────────────────────────────────────────┘

두 번째 호출 (seq_len=1024):
┌──────────────────────────────────────────┐
│ replay(graph_key=1024, x_3d, ...)        │
│  1. rotary workspace 업데이트             │
│  2. block_input.copy_(x_3d)              │
│  3. graph.replay()  ← 커널 재발사 없음!   │
│  4. block_output 반환                     │
└──────────────────────────────────────────┘

핵심 코드 분석

초기화: 모델 특성 탐지

ViTCudaGraphRunner는 ViT 모델의 특성을 자동으로 감지한다.

class ViTCudaGraphRunner:
    def __init__(self, vit: nn.Module):
        self.vit = vit

        # graph_key -> buffers/graphs
        self.block_input: Dict[Hashable, torch.Tensor] = {}
        self.block_ws: Dict[Hashable, torch.Tensor] = {}
        self.block_graphs: Dict[Hashable, torch.cuda.CUDAGraph] = {}
        self.block_output: Dict[Hashable, torch.Tensor] = {}

        # Qwen2.5-VL: windowed attention
        self._fullatt_block_indexes = set(
            getattr(vit, "fullatt_block_indexes", ())
        )
        # Qwen3-VL: deepstack
        self._deepstack_visual_indexes = list(
            getattr(vit, "deepstack_visual_indexes", []) or []
        )
        self._deepstack_merger_list = getattr(vit, "deepstack_merger_list", None)

세 가지 모델 변형을 지원한다.

  • 기본 ViT: blocks + merger
  • Qwen2.5-VL: windowed attention + full attention 블록 혼합
  • Qwen3-VL: deepstack merger 추가

Graph Key: 시퀀스 길이 기반

def _get_graph_key(self, x_3d: torch.Tensor) -> int:
    return x_3d.shape[0]  # 시퀀스 길이

CUDA Graph는 입력 형상이 동일해야 재생할 수 있다. 시퀀스 길이를 key로 사용하여 같은 크기의 이미지가 재방문할 때 그래프를 재활용한다.

Workspace 할당

그래프 캡처 전에 입력, 어텐션 워크스페이스, 출력 버퍼를 사전 할당한다.

def create_graph(self, x_3d, cu_seqlens, cu_window_seqlens, position_embeddings, ...):
    graph_key = self._get_graph_key(x_3d)
    if graph_key in self.block_graphs:
        return graph_key

    attn_module = vit.blocks[0].attn
    num_heads = attn_module.num_attention_heads_per_partition
    attn_head_dim = attn_module.head_size

    self.block_output[graph_key] = torch.empty_like(x_3d).contiguous()
    self.block_input[graph_key] = torch.empty_like(x_3d).contiguous()
    self.block_ws[graph_key] = torch.empty(
        graph_key, num_heads, attn_head_dim,
        device=self.device, dtype=self.dtype,
    )

어텐션 워크스페이스(block_ws)는 (seq_len, num_heads, head_dim) 형상으로, 각 블록의 중간 결과를 저장한다.

그래프 캡처: 블록 순회

def _create_graph(self, graph_key, position_embeddings=None, ...):
    graph = torch.cuda.CUDAGraph()

    with capture_ctx, torch.cuda.graph(graph):
        y = None
        deepstack_outs = []
        for layer_num, blk in enumerate(vit.blocks):
            # Qwen2.5-VL: 블록별 attention 유형 결정
            if self._fullatt_block_indexes:
                if layer_num in vit.fullatt_block_indexes:
                    cu_seqlens_now = cu_full
                else:
                    cu_seqlens_now = cu_window
            # ...
            if layer_num == 0:
                y = blk(self.block_input[graph_key], cu_seqlens=cu_seq_len_ws, ...)
            else:
                y = blk(y, cu_seqlens=cu_seq_len_ws, ...)

            # Qwen3-VL: deepstack 처리
            if layer_num in self._deepstack_visual_indexes:
                deepstack_out = self._deepstack_merger_list[idx](y)
                deepstack_outs.append(deepstack_out)

        main_out = vit.merger(y)
        if deepstack_outs:
            self.block_output[graph_key] = torch.cat(
                [main_out] + deepstack_outs, dim=1
            )

Qwen2.5-VL은 특정 블록에서 full attention, 나머지에서 windowed attention을 사용한다. 그래프 캡처 시 이 분기도 함께 캡처된다.

Rotary Position Embedding 관리

회전 위치 임베딩은 이미지마다 달라지므로 별도 워크스페이스에서 관리한다.

def _ensure_sin_cos_ws(self, seq_len, head_dim):
    if self.sin_cos_ws is None:
        max_shape = self.max_context_len or seq_len
        cos_ws = torch.empty(max_shape, head_dim, ...)
        sin_ws = torch.empty(max_shape, head_dim, ...)
        self.sin_cos_ws = (cos_ws, sin_ws)
    elif self.sin_cos_ws[0].size(0) < seq_len:
        max_shape = max(self.sin_cos_ws[0].size(0) * 2, seq_len)
        # 2배 증가로 재할당 빈도 감소

2배 증가 전략으로 재할당 횟수를 O(log N)으로 줄인다.

재생: 입력 복사 + replay

def replay(self, graph_key, x_3d, position_embeddings=None, ...):
    if position_embeddings is not None:
        used_cos_ws = self.sin_cos_ws[0][:graph_key, :]
        used_sin_ws = self.sin_cos_ws[1][:graph_key, :]
        used_cos_ws.copy_(position_embeddings[0])
        used_sin_ws.copy_(position_embeddings[1])

    self.block_input[graph_key].copy_(x_3d)
    self.block_graphs[graph_key].replay()

    out = self.block_output[graph_key]
    if output_indices is not None:
        out = out.index_select(0, output_indices)
    return out

replay()는 입력 데이터만 복사하고 캡처된 커널 시퀀스를 재실행한다. output_indices는 Qwen2.5-VL의 windowed attention permutation 역변환에 사용된다.

진입점: run()

def run(self, x, cu_seqlens, cu_window_seqlens, position_embeddings, ...):
    x_3d = x.unsqueeze(1)  # [seq_len, hidden] -> [S, 1, H]
    graph_key = self._get_graph_key(x_3d)

    if graph_key not in self.block_graphs:
        self.create_graph(x_3d=x_3d, ...)

    return self.replay(graph_key=graph_key, x_3d=x_3d, ...)

첫 호출 시 캡처, 이후 동일 크기면 재생한다.

성능 영향

CUDA Graph 없음:
  이미지 인코딩 = N블록 × (커널 런치 + 계산)
  커널 런치 오버헤드 ≈ 5~15μs × 수십 개 커널 × N블록

CUDA Graph 적용:
  첫 호출: 캡처 비용 (1회)
  이후: 단일 replay() = 전체 블록 체인 한 번에 실행
  커널 런치 오버헤드 ≈ 0

ViT 블록 수가 많고(24~48블록) 이미지 크기가 반복적일수록 효과가 크다.

설계 근거

설계 선택 이유
seq_len 기반 graph_key 동일 해상도 이미지가 반복될 때 그래프 재활용
별도 sin/cos workspace 위치 임베딩은 이미지마다 달라 매번 복사 필요
2배 증가 워크스페이스 재할당 빈도 O(log N), amortized 비용 절감
contiguous() 강제 CUDA Graph는 연속 메모리 레이아웃을 요구
TP ca_comm.capture() 텐서 병렬의 커뮤니케이션도 그래프에 포함

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글