본문으로 건너뛰기

[SGLang] Efficient Vision Sampling: 이미지 토큰 압축

들어가며

비디오 입력은 수천 개의 이미지 토큰을 생성하여 LLM의 컨텍스트를 빠르게 소진한다. Efficient Vision Sampling(EVS)은 인접 프레임 간 유사한 토큰을 제거하여 비디오 토큰 수를 줄이는 기법이다. SGLang은 EVS를 모듈화하여 다양한 비디오 모델에 적용할 수 있도록 구현했다.

이 글에서는 python/sglang/srt/multimodal/evs/ 디렉토리를 분석한다.

EVS 동작 원리

비디오 임베딩 (T frames × H × W tokens)
    │
    ▼
┌───────────────────────────────────────────┐
│ compute_retention_mask()                   │
│                                           │
│ Frame 0: ████████████ (항상 전체 보존)      │
│ Frame 1: ██████░░░░░░ (유사 토큰 제거)      │
│ Frame 2: ████████░░░░ (변화 큰 영역 보존)   │
│ Frame 3: ██░░░░░░░░░░ (거의 동일 → 많이 제거)│
│                                           │
│ ░ = 제거된 토큰 (인접 프레임과 유사)         │
│ █ = 보존된 토큰 (dissimilarity 높음)        │
└───────────────────┬───────────────────────┘
                    │
                    ▼
    압축된 비디오 임베딩 (retained tokens only)

핵심 코드 분석

evs_core.py: 핵심 알고리즘

Retention Mask 계산

프레임 간 cosine dissimilarity를 기준으로 보존할 토큰을 결정한다.

def compute_retention_mask(video_embeds, video_size_thw, spatial_merge_size, q):
    T, H, W = map(int, video_size_thw)

    video_embeds = video_embeds.reshape(
        T, H // spatial_merge_size, W // spatial_merge_size,
        video_embeds.size(-1),
    )
    tokens_per_frame = (H // spatial_merge_size) * (W // spatial_merge_size)

    # 인접 프레임 간 cosine similarity
    similarity = torch.nn.functional.cosine_similarity(
        video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1
    )
    dissimilarity = 1 - similarity

    # 첫 프레임은 항상 전체 보존 (255로 설정)
    dissimilarity = torch.cat(
        [255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity], dim=0
    )

    # Top-K 방식으로 보존할 토큰 선택
    dissimilarity_flat = dissimilarity.view(-1)
    order = torch.argsort(dissimilarity_flat, descending=True, stable=True)
    retain_num_tokens = compute_retained_tokens_count(
        tokens_per_frame, T, q
    )
    topk_indices = order[:retain_num_tokens]

    retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool)
    retention_mask[topk_indices] = True
    return retention_mask.view(-1)

알고리즘의 핵심은 다음과 같다.

  1. 연속 프레임의 같은 위치 토큰 간 cosine similarity를 계산한다.
  2. 1 - similarity로 dissimilarity를 구한다.
  3. 첫 프레임의 dissimilarity를 255로 설정하여 항상 보존한다.
  4. dissimilarity가 높은 순으로 retain_num_tokens개를 선택한다.

보존 토큰 수 계산

def compute_retained_tokens_count(tokens_per_frame, num_frames, q):
    total_tokens = tokens_per_frame * num_frames
    evs_num_tokens = int(total_tokens * (1 - q))
    min_num_tokens = tokens_per_frame  # 최소 1프레임 분량
    return max(min_num_tokens, evs_num_tokens)

pruning rate q=0.5이면 전체 토큰의 50%를 제거한다. 단, 최소 1프레임 분량은 항상 보존한다.

evs_module.py: 모델 통합 모듈

EVS Mixin 클래스

class EVS(torch.nn.Module, ABC):
    @staticmethod
    @abstractmethod
    def create_evs_config(config: PretrainedConfig) -> EVSConfig:
        raise NotImplementedError

    @abstractmethod
    def get_video_feature(self, items: list[MultimodalDataItem]) -> torch.Tensor:
        raise NotImplementedError

    def __init__(self, config, *args, **kwargs):
        super().__init__()
        self.evs_config = self.create_evs_config(config)
        self.evs_enabled = self.evs_config.video_pruning_rate > 0.0
        if self.evs_enabled:
            self.get_video_feature = self.evs_video

EVS를 모델 클래스에 mixin으로 추가하면, get_video_feature() 메서드가 자동으로 EVS 적용 버전으로 교체된다.

EVS 적용 실행

def evs_video(self, items):
    q = self.evs_config.video_pruning_rate
    merge = self.evs_config.spatial_merge_size
    videos_features = self.original_get_video_feature([item])

    final_embeddings = []
    num_tokens_per_frame = []

    sizes = [(t * h * w // merge**2) for t, h, w in item.thw_grids]
    for single_video, video_size_thw in zip(
        videos_features.split(sizes), item.thw_grids
    ):
        retention_mask = compute_retention_mask(
            single_video, video_size_thw, merge, q
        )
        preserved = single_video[retention_mask]
        final_embeddings.append(preserved)

        tokens_per_frame = (
            retention_mask.reshape(num_frames, -1).sum(dim=-1).tolist()
        )
        num_tokens_per_frame.extend(tokens_per_frame)

    return EVSEmbeddingResult(
        embedding=torch.cat(final_embeddings),
        num_tokens_per_frame=num_tokens_per_frame,
    )

프레임별로 실제 보존된 토큰 수(num_tokens_per_frame)를 기록한다. 이 정보는 이후 input_ids의 플레이스홀더 재조정에 사용된다.

input_ids 재조정

EVS 후 프레임별 토큰 수가 달라지므로 input_ids의 플레이스홀더를 재배치해야 한다.

def replace_offsets_with_tokens_per_frame(
    pre_chunked_input_ids, num_tokens_per_frame,
    frame_offsets_inclusive, filler_token_id
):
    # 예시:
    # input_ids = [1, 0, 0, 4, 5, 0, 0, 0, 9, 10, 0, 0, 12, 13]
    # offsets   = [(1,2), (5,7), (10,11)]
    # new_tpf   = [1, 4, 2]
    # result    = [1, 0, 4, 5, 0, 0, 0, 0, 9, 10, 0, 0, 12, 13]

    cursor = 0
    final = []
    for (start, end), num_tokens in zip(
        frame_offsets_inclusive, num_tokens_per_frame
    ):
        final.extend(ids[cursor:start])
        final.extend([filler_token_id] * num_tokens)
        cursor = end + 1
    final.extend(ids[frame_offsets_inclusive[-1][1] + 1:])
    return final

기존 오프셋 구간의 filler 토큰 수를 EVS 결과에 맞게 조정한다.

evs_processor.py: 프로세서 통합

class EVSProcessor:
    def __init__(self, hf_config, config_to_evs_model):
        evs_model = config_to_evs_model.get(hf_config.__class__)
        evs_config = evs_model.create_evs_config(hf_config)
        if evs_config.video_pruning_rate > 0.0:
            self.evs_config = evs_config

    def static_size_data_items(self, *, frames_per_video, num_images, rows, cols):
        frame_num_tokens = rows * cols

        if self.evs_config is None:
            tpf = [[frame_num_tokens] * num_frames for num_frames in frames_per_video]
            return _non_evs_data_items, tpf

        # EVS 활성화 시: 줄어든 토큰 수로 플레이스홀더 생성
        tpf = [
            tokens_per_frame(
                q=self.evs_config.video_pruning_rate,
                num_frames=num_frames,
                frame_num_tokens=frame_num_tokens,
            )
            for num_frames in frames_per_video
        ]
        return create_evs_data_items, tpf

EVS가 비활성화되면 일반 MultimodalDataItem을, 활성화되면 VideoEVSDataItem을 생성한다. 사전에 줄어든 토큰 수로 플레이스홀더를 할당하여 forward 시 메타데이터 불일치를 방지한다.

EVS 효과 시뮬레이션

32프레임 비디오, 프레임당 256토큰:

q=0.0 (EVS 비활성): 32 × 256 = 8,192 토큰
q=0.3 (30% 제거):   max(256, 8192×0.7) = 5,734 토큰 (30% 감소)
q=0.5 (50% 제거):   max(256, 8192×0.5) = 4,096 토큰 (50% 감소)
q=0.7 (70% 제거):   max(256, 8192×0.3) = 2,457 토큰 (70% 감소)

첫 프레임 (256토큰)은 항상 전체 보존
정적 장면 → 더 많이 제거
동적 장면 → 더 적게 제거

설계 근거

설계 선택 이유
Cosine dissimilarity 기반 인접 프레임 간 의미적 차이를 측정하는 가장 자연스러운 메트릭
첫 프레임 전체 보존 참조 프레임 없이는 dissimilarity를 계산할 수 없으므로 전체 보존
Mixin 패턴 기존 모델 코드 수정 없이 EVS를 선택적으로 추가 가능
사전 플레이스홀더 조정 forward 전에 input_ids 길이를 맞춰 배치 메타데이터 정합성 유지
q 파라미터 0~1 사이 단일 값으로 정확도-효율 트레이드오프 제어

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글