본문으로 건너뛰기

[SGLang] Mamba Radix Cache: SSM 모델을 위한 상태 캐싱

들어가며

Transformer 기반 모델은 KV Cache를 통해 이전 토큰의 attention 연산 결과를 재활용한다. 그런데 Mamba 같은 SSM(State Space Model) 기반 모델은 attention이 없다. 대신 재귀적 상태(recurrent state)를 유지하며, 이 상태가 이전 토큰들의 정보를 압축해서 담고 있다.

SGLang의 MambaRadixCache는 Transformer의 KV Cache와 Mamba의 재귀 상태를 동시에 관리하는 하이브리드 Radix Tree다. python/sglang/srt/mem_cache/mamba_radix_cache.py를 중심으로 이 구조를 분석한다.

구조도

MambaRadixCache
├── root_node (TreeNode)
   ├── children: defaultdict(TreeNode)
   ├── value: torch.Tensor           KV 인덱스 (Transformer용)
   ├── mamba_value: torch.Tensor      Mamba 상태 인덱스
   ├── full_lock_ref: int             KV  카운터
   └── mamba_lock_ref: int            Mamba  카운터
├── full_lru_list (LRUList)            KV용 LRU 리스트
├── mamba_lru_list (LRUList)           Mamba용 LRU 리스트
├── full_evictable_size_: int
└── mamba_evictable_size_: int

TreeNode 이중 연결 리스트 구조:
┌──────┐    ┌──────┐    ┌──────┐    ┌──────┐
 HEAD │◀──▶│ MRU  │◀──▶│ ...  │◀──▶│ TAIL    full LRU
└──────┘    └──────┘    └──────┘    └──────┘
┌──────┐    ┌──────┐    ┌──────┐    ┌──────┐
 HEAD │◀──▶│ MRU  │◀──▶│ ...  │◀──▶│ TAIL    mamba LRU
└──────┘    └──────┘    └──────┘    └──────┘

TreeNode: 이중 상태를 가진 노드

TreeNode는 Transformer의 KV Cache와 Mamba 상태를 동시에 저장한다. 핵심은 두 종류의 값과 두 종류의 락이 공존한다는 점이다.

class TreeNode:
    def __init__(self, id: Optional[int] = None):
        self.children = defaultdict(TreeNode)
        self.parent: TreeNode = None
        self.key: RadixKey = None
        self.value: Optional[torch.Tensor] = None         # KV 인덱스
        self.mamba_value: Optional[torch.Tensor] = None    # Mamba 상태
        self.full_lock_ref = 0
        self.mamba_lock_ref = 0
        # LRU 이중 연결 리스트 포인터 (full / mamba 분리)
        self.prev = None
        self.next = None
        self.mamba_prev = None
        self.mamba_next = None

valuemamba_value의 eviction 상태는 독립적이다. KV Cache는 evict 됐지만 Mamba 상태는 남아있는 경우가 가능하다. 이를 "tombstone" 노드라고 부른다.

@property
def evicted(self):
    return self.value is None

@property
def mamba_evicted(self):
    return self.mamba_value is None

LRUList: 두 개의 독립적인 LRU 관리

LRUList는 full(KV)과 mamba 각각에 대해 독립적인 LRU 순서를 유지한다. 생성 시 mamba 플래그에 따라 다른 포인터 필드를 사용한다.

class LRUList:
    def __init__(self, mamba: bool = False):
        self.mamba = mamba
        if self.mamba:
            self.prv = "mamba_prev"
            self.nxt = "mamba_next"
            self.lock_ref = "mamba_lock_ref"
        else:
            self.prv = "prev"
            self.nxt = "next"
            self.lock_ref = "full_lock_ref"
        self.head = TreeNode()
        self.tail = TreeNode()
        setattr(self.head, self.nxt, self.tail)
        setattr(self.tail, self.prv, self.head)

LRU에서 eviction 대상을 찾을 때, 락이 걸린 노드를 건너뛴다.

def get_lru_no_lock(self) -> Optional[TreeNode]:
    return self.get_prev_no_lock(self.tail, check_id=False)

def get_prev_no_lock(self, node, check_id=True):
    x = getattr(node, self.prv)
    while getattr(x, self.lock_ref) > 0:
        x = getattr(x, self.prv)
    if x == self.head:
        return None
    return x

락 참조 카운트의 비대칭 설계

inc_lock_ref의 동작에서 full과 mamba의 락 전파 방식이 다르다. Full 락은 노드에서 루트까지 전체 경로를 잠그지만, Mamba 락은 해당 노드만 잠근다.

def inc_lock_ref(self, node: TreeNode) -> IncLockRefResult:
    # mamba: 현재 노드만 보호
    if node.mamba_value is not None:
        if node.mamba_lock_ref == 0:
            self.mamba_evictable_size_ -= len(node.mamba_value)
            self.mamba_protected_size_ += len(node.mamba_value)
        node.mamba_lock_ref += 1

    # full: 루트까지 경로 전체 보호
    while node != self.root_node:
        if node.full_lock_ref == 0:
            self.full_evictable_size_ -= len(node.value)
            self.full_protected_size_ += len(node.value)
        node.full_lock_ref += 1
        node = node.parent

이 비대칭의 이유는 명확하다. KV Cache는 prefix 공유 구조이므로 부모 노드의 KV도 보호해야 한다. 반면 Mamba 상태는 특정 시점의 재귀 상태를 나타내므로 해당 노드의 상태만 보호하면 된다.

Mamba Eviction: Tombstone 패턴

Mamba 상태만 evict할 때, 내부 노드는 KV Cache를 유지한 채 Mamba 상태만 제거한다. 이것이 tombstone 패턴이다.

def evict_mamba(self, mamba_num: int) -> int:
    x = self.mamba_lru_list.get_lru_no_lock()
    mamba_num_evicted = 0
    while mamba_num_evicted < mamba_num and self.mamba_lru_list.in_list(x):
        if len(x.children) > 0:
            # 내부 노드: mamba만 해제, KV는 유지 (tombstone)
            self.req_to_token_pool.mamba_pool.free(x.mamba_value)
            mamba_num_evicted += len(x.mamba_value)
            x_next = self.mamba_lru_list.get_prev_no_lock(x)
            self.mamba_lru_list.remove_node(x)
            self._tombstone_internal_node(x)
        else:
            # 리프 노드: KV + mamba 모두 해제
            _, mamba_evicted_delta, _, x_next = self._evict_leaf_node(x, True)
            mamba_num_evicted += mamba_evicted_delta
        x = x_next

반면 full eviction은 리프 노드만 제거한다. 리프가 제거되면 부모가 새 리프가 되고, tombstone 리프가 연쇄적으로 정리된다.

def evict_full(self, full_num_tokens: int) -> int:
    x = self.full_lru_list.get_leaf_lru_no_lock()
    while full_num_evicted < full_num_tokens and self.full_lru_list.in_list(x):
        full_num_evicted_delta, _, x, x_next = self._evict_leaf_node(x, False)
        full_num_evicted += full_num_evicted_delta
        if len(x.parent.children) == 0:
            x_next = self.full_lru_list.get_leaf_lru_no_lock()
        x = x_next

설계 근거: 왜 이중 LRU인가

구분 Full (KV Cache) Mamba (State)
데이터 크기 토큰 수에 비례 고정 크기 (레이어당 1개)
공유 구조 prefix 트리로 공유 노드별 독립
락 전파 루트까지 전체 경로 현재 노드만
Eviction 대상 리프 노드만 내부 노드 포함
Eviction 후 노드 삭제 tombstone (KV 유지)

Mamba 상태는 KV와 크기, 공유 패턴, 생명주기가 모두 다르다. 단일 LRU로 관리하면 한쪽이 불필요하게 evict되거나 보호된다. 이중 LRU는 각 리소스를 독립적으로 최적화할 수 있게 한다.

관련 포스트

  • Radix Cache의 기본 구조와 prefix 매칭
  • 캐시 Eviction 정책: LRU, LFU, FIFO 비교 분석
  • Hybrid Cache Controller: GPU/CPU 하이브리드 캐시 관리

참고

댓글

관련 포스트

SGLang 의 다른글