[SGLang] Multimodal Cache: Vision Encoder 출력 캐싱
들어가며
멀티모달 LLM에서 같은 이미지를 여러 요청이 참조하는 경우가 자주 있다. 예를 들어 하나의 이미지에 대해 "이 이미지를 설명해줘", "이 이미지에서 텍스트를 추출해줘" 같은 여러 질문을 던질 때, 매번 Vision Encoder를 실행하면 GPU 연산이 낭비된다.
SGLang의 MultimodalCache는 Vision Encoder의 출력 임베딩을 해시 기반으로 캐싱한다. 동일한 이미지 조합이 입력되면 인코딩을 건너뛰고 캐시된 임베딩을 재활용한다. python/sglang/srt/mem_cache/multimodal_cache.py를 분석한다.
구조도
MultimodalCache (ABC)
└── MultiModalStaticCache
├── max_size: int ← 바이트 단위 최대 크기
├── current_size: int ← 현재 사용 중인 바이트
└── mm_cache: OrderedDict[int, EmbeddingResult]
│
│ 키: combine_hashes([hash1, hash2, ...])
│ = hash(tuple(mm_hashes))
│
│ 값: EmbeddingResult
│ └── embedding: torch.Tensor
│
│ Eviction: LRU (OrderedDict.popitem(last=False))
│
│ 조회 → 히트 시 move_to_end (MRU로 승격)
│ 삽입 → 용량 초과 시 popitem(last=False) (LRU 제거)
└── 미스 시 None 반환 → Vision Encoder 실행
추상 인터페이스: MultimodalCache
MultimodalCache는 서버 레벨에서 멀티모달 임베딩을 관리하는 추상 클래스다. get, set, has, free, clear 다섯 가지 연산을 정의한다.
class MultimodalCache(abc.ABC):
@staticmethod
def combine_hashes(mm_hashes: List[int]) -> Optional[int]:
if not mm_hashes:
return None
return hash(tuple(mm_hashes))
@abc.abstractmethod
def get(self, mm_hashes: List[int],
combined_hash: Optional[int] = None) -> Optional[torch.Tensor]:
raise NotImplementedError()
@abc.abstractmethod
def set(self, mm_hash: int, embedding: torch.Tensor,
mm_embedding_allocator: BaseTokenToKVPoolAllocator) -> bool:
raise NotImplementedError()
@abc.abstractmethod
def has(self, mm_hash: int) -> bool:
raise NotImplementedError()
combine_hashes()는 여러 멀티모달 아이템의 해시를 하나로 결합한다. 이미지 3장이 입력되면 각 이미지의 해시를 tuple로 묶어 전체 조합의 해시를 만든다. 이미지 순서가 다르면 다른 해시가 된다.
EmbeddingResult: 캐시 값 래퍼
캐시에 저장되는 값은 EmbeddingResult 데이터 클래스다. 임베딩 텐서를 감싼다.
@dataclass(kw_only=True)
class EmbeddingResult:
embedding: torch.Tensor
MultiModalStaticCache: LRU 기반 구현
MultiModalStaticCache는 바이트 크기 제한이 있는 LRU 캐시다. Python의 OrderedDict를 활용하여 O(1) LRU를 구현한다.
class MultiModalStaticCache(MultimodalCache):
def __init__(self, max_size: int):
super().__init__()
self.max_size = max_size
self.mm_cache: OrderedDict[int, EmbeddingResult] = OrderedDict()
self.current_size = 0
조회 (get)
여러 이미지의 해시를 결합한 combined_hash로 조회한다. 히트 시 move_to_end()로 MRU 위치로 승격한다.
def get(self, mm_hashes, combined_hash=None):
combined_hash = self.combine_hashes(mm_hashes)
embedding = self.mm_cache.get(combined_hash)
if embedding is not None:
self.mm_cache.move_to_end(combined_hash)
return embedding
단일 아이템 조회용 get_single()도 제공한다. 이는 combine_hashes를 거치지 않고 단일 해시로 직접 조회한다.
def get_single(self, mm_hash: int) -> Optional[EmbeddingResult]:
embedding = self.mm_cache.get(mm_hash)
if embedding is not None:
self.mm_cache.move_to_end(mm_hash)
return embedding
삽입 (set)
이미 존재하는 해시이면 MRU로 승격만 한다. 새로운 항목이면 용량을 확인하고, 초과 시 LRU 항목을 제거한다.
def set(self, mm_hash, embedding, loc=None) -> bool:
assert isinstance(embedding, EmbeddingResult), embedding
if mm_hash in self.mm_cache:
self.mm_cache.move_to_end(mm_hash)
return True
data_size = _get_tensor_size(embedding.embedding)
while self.current_size + data_size > self.max_size:
if not self.mm_cache:
return False # 캐시가 비어도 단일 항목이 max_size 초과
lru_hash, lru_embedding = self.mm_cache.popitem(last=False)
self.current_size -= _get_tensor_size(lru_embedding.embedding)
self.mm_cache[mm_hash] = embedding
self.current_size += data_size
return True
popitem(last=False)가 LRU eviction의 핵심이다. OrderedDict에서 last=False는 가장 먼저 삽입된(가장 오래 접근 안 된) 항목을 제거한다.
크기 계산
텐서의 바이트 크기를 정확히 계산한다. element_size와 numel의 곱이다.
def _get_tensor_size(embedding: torch.Tensor):
return embedding.element_size() * embedding.numel()
해제와 초기화
개별 해시를 해제하거나 전체를 초기화할 수 있다.
def free(self, mm_hash, mm_embedding_allocator) -> bool:
if mm_hash not in self.mm_cache:
return False
old_embedding = self.mm_cache.pop(mm_hash)
self.current_size -= _get_tensor_size(old_embedding.embedding)
return True
def clear(self):
self.mm_cache.clear()
self.current_size = 0
캐시 키 설계 분석
캐시 키의 설계는 히트율에 직접적인 영향을 미친다.
단일 이미지:
hash(image_bytes) → mm_hash
복수 이미지 조합:
[hash(img1), hash(img2), hash(img3)]
→ combine_hashes → hash(tuple([h1, h2, h3]))
같은 이미지셋, 다른 순서:
[h1, h2] ≠ [h2, h1] ← 다른 해시 (순서 의미 있음)
같은 이미지셋, 같은 순서:
[h1, h2] == [h1, h2] ← 같은 해시 (캐시 히트)
순서를 고려하는 이유는 멀티모달 모델에서 이미지 순서가 attention 패턴에 영향을 주기 때문이다. 이미지 A, B 순서와 B, A 순서에서 인코딩 결과가 달라질 수 있다.
KV Cache와의 차이점
| 항목 | KV Cache (RadixCache) | Multimodal Cache |
|---|---|---|
| 캐시 키 | 토큰 시퀀스 (prefix) | 이미지 해시 조합 |
| 캐시 값 | KV 인덱스 (메모리 풀 참조) | 임베딩 텐서 (직접 저장) |
| 크기 관리 | 토큰 수 기반 | 바이트 수 기반 |
| 데이터 구조 | Radix Tree | OrderedDict |
| 공유 패턴 | prefix 공유 (트리) | 전체 조합 매칭 |
| Eviction | 다양한 정책 선택 가능 | LRU 고정 |
Radix Cache는 토큰 prefix를 트리로 공유하므로 부분 매칭이 가능하지만, Multimodal Cache는 이미지 조합 전체가 일치해야 히트한다. 이미지 인코딩은 prefix 공유 개념이 없기 때문이다.
설계 근거
Vision Encoder는 일반적으로 ViT(Vision Transformer) 기반이며, 이미지 하나당 수백 개의 패치 토큰을 생성한다. 이 연산은 LLM의 prefill보다는 가볍지만, 동일 이미지에 대한 반복 인코딩은 불필요하다.
OrderedDict 기반 LRU를 선택한 이유는 단순성과 효율성의 균형이다. 멀티모달 캐시의 항목 수는 KV Cache 노드 수보다 훨씬 적으므로(이미지 수 << 토큰 수), 복잡한 eviction 정책이 필요하지 않다.
관련 포스트
- 캐시 Eviction 정책: LRU, LFU, FIFO 비교 분석
- Session-Aware Cache: 사용자별 KV 캐시 파티셔닝
- Mamba Radix Cache: SSM 모델을 위한 상태 캐싱
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] 외부 스토리지 백엔드: LMCache, 3FS, Mooncake, NIXL
- 현재글 : [SGLang] Multimodal Cache: Vision Encoder 출력 캐싱
- 다음글 [SGLang] RadixAttention Layer: 통합 어텐션 인터페이스의 설계
댓글