본문으로 건너뛰기

[SGLang] Sliding Window Attention 캐시: SWA 최적화 설계

들어가며

Mistral, Mixtral 등의 모델은 Sliding Window Attention(SWA)을 사용한다. 이 모델들은 attention 윈도우 바깥의 토큰을 참조하지 않으므로, 해당 토큰의 KV 캐시는 불필요하다. 그런데 일반 RadixCache는 이 사실을 모른다. 모든 토큰의 KV 캐시를 동일하게 취급하여 윈도우 밖의 캐시까지 GPU 메모리에 유지한다.

SGLang의 SWARadixCache는 Full Attention 레이어와 SWA 레이어의 KV 캐시를 분리 관리하여 이 낭비를 제거한다. 이 글에서는 python/sglang/srt/mem_cache/swa_radix_cache.py를 중심으로 SWA 캐시의 설계를 분석한다.

이중 KV 캐시 관리

SWARadixCache의 핵심은 하나의 Radix Tree에서 두 종류의 KV 캐시를 동시에 관리하는 것이다.

┌─────────────────────────────────────────────────────────────┐
│  Radix Tree (단일 트리)                                     │
│                                                             │
│  Root ─── [system prompt tokens]                            │
│           │                                                 │
│           ├── [user query A]                                │
│           │   full_lock_ref=1, swa_lock_ref=0               │
│           │   → Full KV: 보존, SWA KV: evict 가능           │
│           │                                                 │
│           └── [user query B] ← sliding window 범위          │
│               full_lock_ref=1, swa_lock_ref=1               │
│               → Full KV: 보존, SWA KV: 보존                │
│                                                             │
│  ┌──────────────┐  ┌──────────────┐                         │
│  │ Full LRU List│  │ SWA LRU List │ ← 독립적 eviction       │
│  └──────────────┘  └──────────────┘                         │
└─────────────────────────────────────────────────────────────┘

TreeNode: 이중 참조 카운트

SWA 전용 TreeNode는 두 개의 독립적인 lock_ref를 가진다.

class TreeNode:
    def __init__(self, id: Optional[int] = None):
        self.children = defaultdict(TreeNode)
        self.parent: TreeNode = None
        self.key: RadixKey = None
        self.value: Optional[torch.Tensor] = None
        self.swa_tombstone = False       # SWA KV가 해제되었는지
        self.full_lock_ref = 0           # Full attention 참조 카운트
        self.swa_lock_ref = 0            # SWA 참조 카운트

        # 이중 연결 리스트용 포인터 (Full LRU)
        self.prev = None
        self.next = None
        # SWA LRU용 별도 포인터
        self.swa_prev = None
        self.swa_next = None

        self.swa_uuid = None  # SWA lock 경계 식별자

핵심 불변식이 있다. swa_lock_ref > 0이면 반드시 full_lock_ref > 0이다. 역방향은 성립하지 않는다. 즉, SWA가 보호하는 노드는 Full도 반드시 보호하지만, Full이 보호하는 노드가 SWA까지 보호하지는 않는다.

LRU List: 이중 연결 리스트

SWARadixCache는 heap 대신 이중 연결 리스트로 LRU를 관리한다. Full용과 SWA용 두 개의 독립적 리스트가 있다.

class LRUList:
    def __init__(self, is_swa_list: bool = False):
        self.is_swa_list = is_swa_list
        if self.is_swa_list:
            self.prv = "swa_prev"
            self.nxt = "swa_next"
            self.lock_ref = "swa_lock_ref"
        else:
            self.prv = "prev"
            self.nxt = "next"
            self.lock_ref = "full_lock_ref"
        # 더미 head/tail로 경계 조건 단순화
        self.head = TreeNode()  # Most recently used
        self.tail = TreeNode()  # Least recently used
        setattr(self.head, self.nxt, self.tail)
        setattr(self.tail, self.prv, self.head)
        self.cache = {}

같은 노드가 두 리스트에 동시에 존재하되, 각 리스트에서의 위치(MRU/LRU)는 독립적으로 관리된다. getattr/setattr을 사용하여 하나의 LRUList 구현으로 두 종류의 포인터를 처리한다.

inc_lock_ref: 슬라이딩 윈도우 범위 계산

요청이 노드를 참조할 때, Full lock은 루트까지 전체 경로에 걸지만, SWA lock은 sliding_window_size 범위만 건다.

def inc_lock_ref(self, node: TreeNode) -> IncLockRefResult:
    swa_lock_size = 0
    swa_uuid_for_lock = None

    while node != self.root_node:
        # Full: 루트까지 전부 잠금
        if node.full_lock_ref == 0:
            self.full_evictable_size_ -= len(node.value)
            self.full_protected_size_ += len(node.value)
        node.full_lock_ref += 1

        # SWA: sliding_window_size만큼만 잠금
        if swa_lock_size < self.sliding_window_size:
            if node.swa_lock_ref == 0:
                self.swa_evictable_size_ -= len(node.value)
                self.swa_protected_size_ += len(node.value)
            node.swa_lock_ref += 1
            swa_lock_size += len(node.value)
            if swa_lock_size >= self.sliding_window_size:
                if node.swa_uuid is None:
                    node.swa_uuid = gen_swa_uuid()
                swa_uuid_for_lock = node.swa_uuid
        node = node.parent

    return IncLockRefResult(swa_uuid_for_lock=swa_uuid_for_lock)

swa_uuid_for_lock는 SWA lock의 경계 노드를 식별한다. 나중에 dec_lock_ref에서 이 UUID를 사용하여 정확히 같은 범위를 해제한다.

노드 경로: RootABCD (last_node)
                                ↑
sliding_window_size에 도달    SWA lock 경계
                              swa_uuid 부여

full_lock_ref: A=1, B=1, C=1, D=1  (전부)
swa_lock_ref:       B=0, C=1, D=1  (윈도우 범위만)

Tombstone: SWA KV 해제

윈도우 밖 노드의 SWA KV를 해제할 때, 노드 자체를 삭제하지 않고 tombstone으로 표시한다. Full KV는 여전히 유효하기 때문이다.

def _tombstone_internal_node(self, x: TreeNode):
    """내부 노드를 SWA tombstone으로 표시"""
    x.swa_tombstone = True
    # SWA LRU 리스트에서 제거 (이미 remove_node으로 처리)
    # Full LRU 리스트에서는 유지

Eviction 시 tombstone 리프가 생기면 연쇄적으로 정리한다.

def _iteratively_delete_tombstone_leaf(self, x):
    """tombstone 리프 노드를 연쇄 삭제"""
    leaf_full_num_evicted = 0
    while (x.parent and len(x.parent.children) == 1
           and x.parent.swa_tombstone
           and x.parent != self.root_node):
        # tombstone 부모가 유일한 자식을 잃으면 삭제
        parent = x.parent
        self.full_lru_list.remove_node(parent)
        self._delete_leaf(parent)
        leaf_full_num_evicted += len(parent.value)
        x = parent
    return x, leaf_full_num_evicted

Eviction: 2단계 전략

Eviction은 Full과 SWA를 순서대로 처리한다.

def evict(self, params: EvictParams) -> EvictResult:
    full_num_tokens = params.num_tokens
    swa_num_tokens = params.swa_num_tokens
    full_num_evicted = 0
    swa_num_evicted = 0

    # 1단계: Full eviction (리프 노드만, Full + SWA 모두 해제)
    if full_num_tokens > 0:
        x = self.full_lru_list.get_leaf_lru_no_lock()
        while full_num_evicted < full_num_tokens and ...:
            self.token_to_kv_pool_allocator.free(x.value)
            full_num_evicted += len(x.value)
            swa_num_evicted += len(x.value)
            self._delete_leaf(x)
            ...

    # 2단계: SWA 추가 eviction (내부 노드도 대상)
    if swa_num_evicted < swa_num_tokens:
        x = self.swa_lru_list.get_lru_no_lock()
        while swa_num_evicted < swa_num_tokens and ...:
            if len(x.children) > 0:
                # 내부 노드: SWA만 해제, tombstone 처리
                self.token_to_kv_pool_allocator.free_swa(x.value)
                swa_num_evicted += len(x.value)
                self._tombstone_internal_node(x)
            else:
                # 리프 노드: Full + SWA 모두 해제
                self.token_to_kv_pool_allocator.free(x.value)
                full_num_evicted += len(x.value)
                swa_num_evicted += len(x.value)
                self._delete_leaf(x)

Full eviction은 리프 노드만 대상으로 하고, SWA eviction은 내부 노드의 SWA KV도 해제할 수 있다. 내부 노드의 SWA를 해제하면 tombstone이 되어 Full KV는 유지되지만 SWA KV 메모리를 회수한다.

SWA Allocator와의 연동

SWARadixCache는 SWATokenToKVPoolAllocator와 함께 동작한다. 이 allocator는 Full과 SWA 메모리 풀을 분리 관리한다.

class SWARadixCache(BasePrefixCache):
    def __init__(self, params: CacheInitParams):
        assert isinstance(
            params.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
        self.sliding_window_size = params.sliding_window_size

free_swa는 SWA 풀만 해제하고, free는 Full과 SWA 풀을 모두 해제한다. 이 구분이 tombstone 메커니즘을 가능하게 한다.

설계 근거

이중 LRU를 사용하는 이유: Heap 기반 eviction은 삽입/삭제마다 O(log n) 비용이 든다. SWA에서는 eviction이 빈번하므로, O(1)인 이중 연결 리스트가 더 효율적이다. 또한 Full과 SWA의 eviction 순서가 독립적이어야 하므로, 두 개의 별도 리스트가 필요하다.

Tombstone을 사용하는 이유: SWA KV를 해제할 때 노드를 삭제하면 트리 구조가 깨진다. 해당 노드의 자식이 존재하면 Full KV를 가진 자식과의 연결을 유지해야 한다. Tombstone은 트리 구조를 유지하면서 SWA 메모리만 회수하는 깔끔한 해법이다.

swa_uuid를 사용하는 이유: SWA lock의 경계 노드를 정확히 기억해야 dec_lock_ref에서 같은 범위를 해제할 수 있다. 노드 분할로 인해 경계 노드의 토큰 수가 변할 수 있으므로, 토큰 수 대신 UUID로 경계를 식별한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글