본문으로 건너뛰기

[SGLang] RadixAttention: Radix Tree 기반 프리픽스 캐싱의 핵심

들어가며

LLM serving에서 동일한 system prompt를 공유하는 수천 개의 요청이 들어온다고 가정하자. 각 요청마다 system prompt에 대한 KV 캐시를 매번 새로 계산한다면, 동일한 연산을 반복하는 엄청난 낭비가 발생한다. SGLang의 RadixAttention은 이 문제를 Radix Tree 자료구조로 해결한다. 2024년 1월 SGLang 블로그 포스트에서 PagedAttention 대비 최대 5배 성능 향상을 보고한 핵심 기술이다.

이 글에서는 python/sglang/srt/mem_cache/radix_cache.py를 중심으로 RadixAttention의 설계를 분석한다.

PagedAttention vs RadixAttention

두 접근법의 핵심 차이는 KV 캐시를 어떻게 관리하느냐에 있다.

┌──────────────────────────────────────────────────────────────┐
│                    PagedAttention (vLLM)                     │
├──────────────────────────────────────────────────────────────┤
│  요청 A: [system prompt] → KV 블록 할당                      │
│  요청 B: [system prompt] → KV 블록 재할당 (중복 계산!)        │
│  요청 C: [system prompt] → KV 블록 재할당 (또 중복 계산!)     │
│  → 동일 프리픽스에 대해 O(N) 중복 연산                        │
├──────────────────────────────────────────────────────────────┤
│                  RadixAttention (SGLang)                     │
├──────────────────────────────────────────────────────────────┤
│           ┌─ [system prompt] ─ KV 캐시 (공유)                │
│           │                                                  │
│  Root ────┼─ 요청 A: 캐시 HIT → 즉시 재사용                  │
│           ├─ 요청 B: 캐시 HIT → 즉시 재사용                  │
│           └─ 요청 C: 캐시 HIT → 즉시 재사용                  │
│  → 프리픽스 공유로 O(1) 재사용                                │
└──────────────────────────────────────────────────────────────┘
비교 항목 PagedAttention (vLLM) RadixAttention (SGLang)
캐시 단위 고정 크기 블록 Radix Tree 노드 (가변 길이)
프리픽스 공유 블록 해시 매칭 (사후적) 트리 구조 (자동적)
삽입 복잡도 O(n) 해시 계산 O(k) 트리 탐색 (k=깊이)
부분 매칭 블록 경계에서만 토큰 단위 정밀 분할
멀티턴 대화 매번 재계산 이전 턴 자동 캐싱
eviction 블록 단위 LRU 노드 단위 다중 정책

Radix Tree 구조도

SGLang의 Radix Tree는 토큰 시퀀스를 키로, KV 캐시 인덱스를 값으로 저장한다.

                    Root (lock_ref=1)
                   /        \
          [1,2,3]            [8,9,10,11,12]
         KV=[0,1,2]          KV=[7,8,9,10,11]
        /         \
   [4,5]          [13,14]
  KV=[3,4]       KV=[5,6]

  → [1,2,3]을 공유하는 두 분기가 KV 캐시를 자동 공유
  → 새 요청 [1,2,3,4,5,6]이 오면 [1,2,3] + [4,5]까지 HIT

핵심 자료구조: TreeNode

트리의 각 노드는 TreeNode 클래스로 표현된다.

class TreeNode:
    counter = 0

    def __init__(self, id: Optional[int] = None, priority: int = 0):
        self.children = defaultdict(TreeNode)
        self.parent: TreeNode = None
        self.key: RadixKey = None        # 토큰 ID 시퀀스
        self.value: Optional[torch.Tensor] = None  # KV 캐시 인덱스
        self.lock_ref = 0                # 참조 카운트 (0이면 evict 가능)
        self.last_access_time = time.monotonic()
        self.hit_count = 0               # LFU/SLRU 정책용
        self.priority = priority         # 우선순위 eviction용
        self.host_value: Optional[torch.Tensor] = None  # CPU 백업용

lock_ref가 핵심이다. 요청이 노드를 참조하면 루트까지의 경로 전체를 lock하여 eviction으로부터 보호한다. valueNone이면 evicted 상태를 나타낸다.

프리픽스 매칭: match_prefix

요청이 들어오면 가장 긴 캐시된 프리픽스를 찾는다.

def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
    access_time = time.monotonic()
    node.last_access_time = access_time
    child_key = self.get_child_key_fn(key)

    value = []
    while len(key) > 0 and child_key in node.children.keys():
        child = node.children[child_key]
        child.last_access_time = access_time
        prefix_len = self.key_match_fn(child.key, key)
        if prefix_len < len(child.key):
            # 부분 매칭 → 노드 분할
            new_node = self._split_node(child.key, child, prefix_len)
            value.append(new_node.value)
            node = new_node
            break
        else:
            value.append(child.value)
            node = child
            key = key[prefix_len:]
            if len(key):
                child_key = self.get_child_key_fn(key)

    return value, node

핵심 동작은 두 가지다. 첫째, 자식 노드의 키와 입력 키를 비교하며 트리를 내려간다. 둘째, 부분 매칭이 발생하면 _split_node로 노드를 분할하여 정확한 경계를 만든다.

노드 분할: _split_node

Radix Tree의 핵심 연산이다. 기존 노드를 공유 프리픽스 부분과 나머지로 나눈다.

def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
    new_node = TreeNode(priority=child.priority)
    new_node.children = {self.get_child_key_fn(key[split_len:]): child}
    new_node.parent = child.parent
    new_node.lock_ref = child.lock_ref
    new_node.key = child.key[:split_len]
    new_node.value = child.value[:split_len].clone()
    child.parent = new_node
    child.key = child.key[split_len:]
    child.value = child.value[split_len:].clone()
    new_node.parent.children[self.get_child_key_fn(key)] = new_node
    return new_node

분할 전후를 시각화하면 이렇다.

분할 전:  Parent → [1,2,3,4,5] (child)
                     ↑ split_len=3

분할 후:  Parent → [1,2,3] (new_node) → [4,5] (child)

삽입: _insert_helper

새 KV 캐시를 트리에 삽입하는 과정이다.

def _insert_helper(self, node, key, value, priority=0, chunked=False):
    if priority is None:
        priority = 0
    access_time = time.monotonic()
    node.last_access_time = access_time
    node.priority = max(node.priority, priority)

    child_key = self.get_child_key_fn(key)
    total_prefix_length = 0

    while len(key) > 0 and child_key in node.children.keys():
        node = node.children[child_key]
        prefix_len = self.key_match_fn(node.key, key)
        total_prefix_length += prefix_len
        key = key[prefix_len:]
        value = value[prefix_len:]
        # ...부분 매칭 시 분할...

    if len(key):
        new_node = TreeNode(priority=priority)
        new_node.parent = node
        new_node.key = key
        new_node.value = value.clone()
        node.children[child_key] = new_node
        self.evictable_size_ += len(key)
    return total_prefix_length

반환값 total_prefix_length는 이미 캐시에 존재하던 프리픽스 길이다. 호출자는 이 값을 사용해 중복 KV 인덱스를 해제한다.

Eviction 전략

SGLang은 7가지 eviction 정책을 지원한다.

if self.eviction_policy == "lru":
    self.eviction_strategy = LRUStrategy()
elif self.eviction_policy == "lfu":
    self.eviction_strategy = LFUStrategy()
elif self.eviction_policy == "slru":
    self.eviction_strategy = SLRUStrategy()
# ... fifo, mru, filo, priority

eviction은 리프 노드부터 시작하여 부모로 전파된다.

def evict(self, params: EvictParams) -> EvictResult:
    leaves = list(self.evictable_leaves)
    eviction_heap = [
        (self.eviction_strategy.get_priority(node), node) for node in leaves
    ]
    heapq.heapify(eviction_heap)

    num_evicted = 0
    while num_evicted < num_tokens and len(eviction_heap):
        _priority, x = heapq.heappop(eviction_heap)
        self.token_to_kv_pool_allocator.free(x.value)
        num_evicted += len(x.value)
        self._delete_leaf(x)
        # 부모가 새로운 리프가 되면 힙에 추가
        if len(x.parent.children) == 0 and x.parent.lock_ref == 0:
            heapq.heappush(eviction_heap, (..., x.parent))

참조 카운트: lock_ref

요청이 캐시 노드를 사용할 때, 해당 노드부터 루트까지 전체 경로를 보호한다.

def inc_lock_ref(self, node: TreeNode) -> IncLockRefResult:
    delta = 0
    while node != self.root_node:
        if node.lock_ref == 0:
            self.evictable_size_ -= len(node.key)
            self.protected_size_ += len(node.key)
            delta -= len(node.key)
        node.lock_ref += 1
        self._update_leaf_status(node)
        node = node.parent
    return IncLockRefResult(delta=delta)

이 설계 덕분에 한 요청이 사용 중인 프리픽스는 다른 요청의 eviction으로부터 안전하게 보호된다.

설계 근거

RadixAttention의 핵심 설계 선택 세 가지를 정리한다.

Radix Tree를 선택한 이유: Trie와 달리 Radix Tree는 공통 프리픽스를 하나의 노드로 압축한다. LLM 요청의 토큰 시퀀스는 수천 개에 달하므로, 토큰마다 노드를 만드는 Trie는 메모리와 탐색 비용이 비현실적이다. Radix Tree는 분기가 발생하는 지점에서만 노드를 생성하므로 공간 효율이 높다.

동적 노드 분할: _split_node를 통해 필요한 시점에만 노드를 분할한다. 사전에 고정 크기로 쪼개지 않으므로, 실제 공유 패턴에 맞춰 트리가 자연스럽게 진화한다.

다중 eviction 정책: LRU만으로는 부족한 시나리오가 있다. 예를 들어 SLRU는 hit count 기반으로 자주 사용되는 프리픽스를 보호하고, Priority 정책은 사용자가 명시적으로 중요도를 지정할 수 있게 한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글