[SGLang] Sliding Window Attention 캐시: SWA 최적화 설계
들어가며
Mistral, Mixtral 등의 모델은 Sliding Window Attention(SWA)을 사용한다. 이 모델들은 attention 윈도우 바깥의 토큰을 참조하지 않으므로, 해당 토큰의 KV 캐시는 불필요하다. 그런데 일반 RadixCache는 이 사실을 모른다. 모든 토큰의 KV 캐시를 동일하게 취급하여 윈도우 밖의 캐시까지 GPU 메모리에 유지한다.
SGLang의 SWARadixCache는 Full Attention 레이어와 SWA 레이어의 KV 캐시를 분리 관리하여 이 낭비를 제거한다. 이 글에서는 python/sglang/srt/mem_cache/swa_radix_cache.py를 중심으로 SWA 캐시의 설계를 분석한다.
이중 KV 캐시 관리
SWARadixCache의 핵심은 하나의 Radix Tree에서 두 종류의 KV 캐시를 동시에 관리하는 것이다.
┌─────────────────────────────────────────────────────────────┐
│ Radix Tree (단일 트리) │
│ │
│ Root ─── [system prompt tokens] │
│ │ │
│ ├── [user query A] │
│ │ full_lock_ref=1, swa_lock_ref=0 │
│ │ → Full KV: 보존, SWA KV: evict 가능 │
│ │ │
│ └── [user query B] ← sliding window 범위 │
│ full_lock_ref=1, swa_lock_ref=1 │
│ → Full KV: 보존, SWA KV: 보존 │
│ │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ Full LRU List│ │ SWA LRU List │ ← 독립적 eviction │
│ └──────────────┘ └──────────────┘ │
└─────────────────────────────────────────────────────────────┘
TreeNode: 이중 참조 카운트
SWA 전용 TreeNode는 두 개의 독립적인 lock_ref를 가진다.
class TreeNode:
def __init__(self, id: Optional[int] = None):
self.children = defaultdict(TreeNode)
self.parent: TreeNode = None
self.key: RadixKey = None
self.value: Optional[torch.Tensor] = None
self.swa_tombstone = False # SWA KV가 해제되었는지
self.full_lock_ref = 0 # Full attention 참조 카운트
self.swa_lock_ref = 0 # SWA 참조 카운트
# 이중 연결 리스트용 포인터 (Full LRU)
self.prev = None
self.next = None
# SWA LRU용 별도 포인터
self.swa_prev = None
self.swa_next = None
self.swa_uuid = None # SWA lock 경계 식별자
핵심 불변식이 있다. swa_lock_ref > 0이면 반드시 full_lock_ref > 0이다. 역방향은 성립하지 않는다. 즉, SWA가 보호하는 노드는 Full도 반드시 보호하지만, Full이 보호하는 노드가 SWA까지 보호하지는 않는다.
LRU List: 이중 연결 리스트
SWARadixCache는 heap 대신 이중 연결 리스트로 LRU를 관리한다. Full용과 SWA용 두 개의 독립적 리스트가 있다.
class LRUList:
def __init__(self, is_swa_list: bool = False):
self.is_swa_list = is_swa_list
if self.is_swa_list:
self.prv = "swa_prev"
self.nxt = "swa_next"
self.lock_ref = "swa_lock_ref"
else:
self.prv = "prev"
self.nxt = "next"
self.lock_ref = "full_lock_ref"
# 더미 head/tail로 경계 조건 단순화
self.head = TreeNode() # Most recently used
self.tail = TreeNode() # Least recently used
setattr(self.head, self.nxt, self.tail)
setattr(self.tail, self.prv, self.head)
self.cache = {}
같은 노드가 두 리스트에 동시에 존재하되, 각 리스트에서의 위치(MRU/LRU)는 독립적으로 관리된다. getattr/setattr을 사용하여 하나의 LRUList 구현으로 두 종류의 포인터를 처리한다.
inc_lock_ref: 슬라이딩 윈도우 범위 계산
요청이 노드를 참조할 때, Full lock은 루트까지 전체 경로에 걸지만, SWA lock은 sliding_window_size 범위만 건다.
def inc_lock_ref(self, node: TreeNode) -> IncLockRefResult:
swa_lock_size = 0
swa_uuid_for_lock = None
while node != self.root_node:
# Full: 루트까지 전부 잠금
if node.full_lock_ref == 0:
self.full_evictable_size_ -= len(node.value)
self.full_protected_size_ += len(node.value)
node.full_lock_ref += 1
# SWA: sliding_window_size만큼만 잠금
if swa_lock_size < self.sliding_window_size:
if node.swa_lock_ref == 0:
self.swa_evictable_size_ -= len(node.value)
self.swa_protected_size_ += len(node.value)
node.swa_lock_ref += 1
swa_lock_size += len(node.value)
if swa_lock_size >= self.sliding_window_size:
if node.swa_uuid is None:
node.swa_uuid = gen_swa_uuid()
swa_uuid_for_lock = node.swa_uuid
node = node.parent
return IncLockRefResult(swa_uuid_for_lock=swa_uuid_for_lock)
swa_uuid_for_lock는 SWA lock의 경계 노드를 식별한다. 나중에 dec_lock_ref에서 이 UUID를 사용하여 정확히 같은 범위를 해제한다.
노드 경로: Root ← A ← B ← C ← D (last_node)
↑
sliding_window_size에 도달 SWA lock 경계
swa_uuid 부여
full_lock_ref: A=1, B=1, C=1, D=1 (전부)
swa_lock_ref: B=0, C=1, D=1 (윈도우 범위만)
Tombstone: SWA KV 해제
윈도우 밖 노드의 SWA KV를 해제할 때, 노드 자체를 삭제하지 않고 tombstone으로 표시한다. Full KV는 여전히 유효하기 때문이다.
def _tombstone_internal_node(self, x: TreeNode):
"""내부 노드를 SWA tombstone으로 표시"""
x.swa_tombstone = True
# SWA LRU 리스트에서 제거 (이미 remove_node으로 처리)
# Full LRU 리스트에서는 유지
Eviction 시 tombstone 리프가 생기면 연쇄적으로 정리한다.
def _iteratively_delete_tombstone_leaf(self, x):
"""tombstone 리프 노드를 연쇄 삭제"""
leaf_full_num_evicted = 0
while (x.parent and len(x.parent.children) == 1
and x.parent.swa_tombstone
and x.parent != self.root_node):
# tombstone 부모가 유일한 자식을 잃으면 삭제
parent = x.parent
self.full_lru_list.remove_node(parent)
self._delete_leaf(parent)
leaf_full_num_evicted += len(parent.value)
x = parent
return x, leaf_full_num_evicted
Eviction: 2단계 전략
Eviction은 Full과 SWA를 순서대로 처리한다.
def evict(self, params: EvictParams) -> EvictResult:
full_num_tokens = params.num_tokens
swa_num_tokens = params.swa_num_tokens
full_num_evicted = 0
swa_num_evicted = 0
# 1단계: Full eviction (리프 노드만, Full + SWA 모두 해제)
if full_num_tokens > 0:
x = self.full_lru_list.get_leaf_lru_no_lock()
while full_num_evicted < full_num_tokens and ...:
self.token_to_kv_pool_allocator.free(x.value)
full_num_evicted += len(x.value)
swa_num_evicted += len(x.value)
self._delete_leaf(x)
...
# 2단계: SWA 추가 eviction (내부 노드도 대상)
if swa_num_evicted < swa_num_tokens:
x = self.swa_lru_list.get_lru_no_lock()
while swa_num_evicted < swa_num_tokens and ...:
if len(x.children) > 0:
# 내부 노드: SWA만 해제, tombstone 처리
self.token_to_kv_pool_allocator.free_swa(x.value)
swa_num_evicted += len(x.value)
self._tombstone_internal_node(x)
else:
# 리프 노드: Full + SWA 모두 해제
self.token_to_kv_pool_allocator.free(x.value)
full_num_evicted += len(x.value)
swa_num_evicted += len(x.value)
self._delete_leaf(x)
Full eviction은 리프 노드만 대상으로 하고, SWA eviction은 내부 노드의 SWA KV도 해제할 수 있다. 내부 노드의 SWA를 해제하면 tombstone이 되어 Full KV는 유지되지만 SWA KV 메모리를 회수한다.
SWA Allocator와의 연동
SWARadixCache는 SWATokenToKVPoolAllocator와 함께 동작한다. 이 allocator는 Full과 SWA 메모리 풀을 분리 관리한다.
class SWARadixCache(BasePrefixCache):
def __init__(self, params: CacheInitParams):
assert isinstance(
params.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
self.sliding_window_size = params.sliding_window_size
free_swa는 SWA 풀만 해제하고, free는 Full과 SWA 풀을 모두 해제한다. 이 구분이 tombstone 메커니즘을 가능하게 한다.
설계 근거
이중 LRU를 사용하는 이유: Heap 기반 eviction은 삽입/삭제마다 O(log n) 비용이 든다. SWA에서는 eviction이 빈번하므로, O(1)인 이중 연결 리스트가 더 효율적이다. 또한 Full과 SWA의 eviction 순서가 독립적이어야 하므로, 두 개의 별도 리스트가 필요하다.
Tombstone을 사용하는 이유: SWA KV를 해제할 때 노드를 삭제하면 트리 구조가 깨진다. 해당 노드의 자식이 존재하면 Full KV를 가진 자식과의 연결을 유지해야 한다. Tombstone은 트리 구조를 유지하면서 SWA 메모리만 회수하는 깔끔한 해법이다.
swa_uuid를 사용하는 이유: SWA lock의 경계 노드를 정확히 기억해야 dec_lock_ref에서 같은 범위를 해제할 수 있다. 노드 분할로 인해 경계 노드의 토큰 수가 변할 수 있으므로, 토큰 수 대신 UUID로 경계를 식별한다.
관련 포스트
- RadixAttention: Radix Tree 기반 프리픽스 캐싱의 핵심
- GPU Memory Pool: 블록 기반 KV 캐시 메모리 할당
- Allocator: 토큰-KV 풀 할당 전략의 설계
- HiRadixCache: 계층적 GPU/CPU/Disk KV 캐시
참고
관련 포스트
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
- [sglang] sglang diffusion 모델 성능 향상: Cache-DiT와 torch.compile의 최적화된 적용 순서
- [sglang] NixlKVManager 성능 향상: 비동기 및 멀티스레드 KV 전송 도입
SGLang 의 다른글
- 이전글 [SGLang] HiRadixCache: 계층적 GPU/CPU/Disk KV 캐시
- 현재글 : [SGLang] Sliding Window Attention 캐시: SWA 최적화 설계
- 다음글 [SGLang] Mamba Radix Cache: SSM 모델을 위한 상태 캐싱
댓글