[SGLang] RadixAttention: Radix Tree 기반 프리픽스 캐싱의 핵심
들어가며
LLM serving에서 동일한 system prompt를 공유하는 수천 개의 요청이 들어온다고 가정하자. 각 요청마다 system prompt에 대한 KV 캐시를 매번 새로 계산한다면, 동일한 연산을 반복하는 엄청난 낭비가 발생한다. SGLang의 RadixAttention은 이 문제를 Radix Tree 자료구조로 해결한다. 2024년 1월 SGLang 블로그 포스트에서 PagedAttention 대비 최대 5배 성능 향상을 보고한 핵심 기술이다.
이 글에서는 python/sglang/srt/mem_cache/radix_cache.py를 중심으로 RadixAttention의 설계를 분석한다.
PagedAttention vs RadixAttention
두 접근법의 핵심 차이는 KV 캐시를 어떻게 관리하느냐에 있다.
┌──────────────────────────────────────────────────────────────┐
│ PagedAttention (vLLM) │
├──────────────────────────────────────────────────────────────┤
│ 요청 A: [system prompt] → KV 블록 할당 │
│ 요청 B: [system prompt] → KV 블록 재할당 (중복 계산!) │
│ 요청 C: [system prompt] → KV 블록 재할당 (또 중복 계산!) │
│ → 동일 프리픽스에 대해 O(N) 중복 연산 │
├──────────────────────────────────────────────────────────────┤
│ RadixAttention (SGLang) │
├──────────────────────────────────────────────────────────────┤
│ ┌─ [system prompt] ─ KV 캐시 (공유) │
│ │ │
│ Root ────┼─ 요청 A: 캐시 HIT → 즉시 재사용 │
│ ├─ 요청 B: 캐시 HIT → 즉시 재사용 │
│ └─ 요청 C: 캐시 HIT → 즉시 재사용 │
│ → 프리픽스 공유로 O(1) 재사용 │
└──────────────────────────────────────────────────────────────┘
| 비교 항목 | PagedAttention (vLLM) | RadixAttention (SGLang) |
|---|---|---|
| 캐시 단위 | 고정 크기 블록 | Radix Tree 노드 (가변 길이) |
| 프리픽스 공유 | 블록 해시 매칭 (사후적) | 트리 구조 (자동적) |
| 삽입 복잡도 | O(n) 해시 계산 | O(k) 트리 탐색 (k=깊이) |
| 부분 매칭 | 블록 경계에서만 | 토큰 단위 정밀 분할 |
| 멀티턴 대화 | 매번 재계산 | 이전 턴 자동 캐싱 |
| eviction | 블록 단위 LRU | 노드 단위 다중 정책 |
Radix Tree 구조도
SGLang의 Radix Tree는 토큰 시퀀스를 키로, KV 캐시 인덱스를 값으로 저장한다.
Root (lock_ref=1)
/ \
[1,2,3] [8,9,10,11,12]
KV=[0,1,2] KV=[7,8,9,10,11]
/ \
[4,5] [13,14]
KV=[3,4] KV=[5,6]
→ [1,2,3]을 공유하는 두 분기가 KV 캐시를 자동 공유
→ 새 요청 [1,2,3,4,5,6]이 오면 [1,2,3] + [4,5]까지 HIT
핵심 자료구조: TreeNode
트리의 각 노드는 TreeNode 클래스로 표현된다.
class TreeNode:
counter = 0
def __init__(self, id: Optional[int] = None, priority: int = 0):
self.children = defaultdict(TreeNode)
self.parent: TreeNode = None
self.key: RadixKey = None # 토큰 ID 시퀀스
self.value: Optional[torch.Tensor] = None # KV 캐시 인덱스
self.lock_ref = 0 # 참조 카운트 (0이면 evict 가능)
self.last_access_time = time.monotonic()
self.hit_count = 0 # LFU/SLRU 정책용
self.priority = priority # 우선순위 eviction용
self.host_value: Optional[torch.Tensor] = None # CPU 백업용
lock_ref가 핵심이다. 요청이 노드를 참조하면 루트까지의 경로 전체를 lock하여 eviction으로부터 보호한다. value가 None이면 evicted 상태를 나타낸다.
프리픽스 매칭: match_prefix
요청이 들어오면 가장 긴 캐시된 프리픽스를 찾는다.
def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
access_time = time.monotonic()
node.last_access_time = access_time
child_key = self.get_child_key_fn(key)
value = []
while len(key) > 0 and child_key in node.children.keys():
child = node.children[child_key]
child.last_access_time = access_time
prefix_len = self.key_match_fn(child.key, key)
if prefix_len < len(child.key):
# 부분 매칭 → 노드 분할
new_node = self._split_node(child.key, child, prefix_len)
value.append(new_node.value)
node = new_node
break
else:
value.append(child.value)
node = child
key = key[prefix_len:]
if len(key):
child_key = self.get_child_key_fn(key)
return value, node
핵심 동작은 두 가지다. 첫째, 자식 노드의 키와 입력 키를 비교하며 트리를 내려간다. 둘째, 부분 매칭이 발생하면 _split_node로 노드를 분할하여 정확한 경계를 만든다.
노드 분할: _split_node
Radix Tree의 핵심 연산이다. 기존 노드를 공유 프리픽스 부분과 나머지로 나눈다.
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
new_node = TreeNode(priority=child.priority)
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
new_node.parent = child.parent
new_node.lock_ref = child.lock_ref
new_node.key = child.key[:split_len]
new_node.value = child.value[:split_len].clone()
child.parent = new_node
child.key = child.key[split_len:]
child.value = child.value[split_len:].clone()
new_node.parent.children[self.get_child_key_fn(key)] = new_node
return new_node
분할 전후를 시각화하면 이렇다.
분할 전: Parent → [1,2,3,4,5] (child)
↑ split_len=3
분할 후: Parent → [1,2,3] (new_node) → [4,5] (child)
삽입: _insert_helper
새 KV 캐시를 트리에 삽입하는 과정이다.
def _insert_helper(self, node, key, value, priority=0, chunked=False):
if priority is None:
priority = 0
access_time = time.monotonic()
node.last_access_time = access_time
node.priority = max(node.priority, priority)
child_key = self.get_child_key_fn(key)
total_prefix_length = 0
while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
prefix_len = self.key_match_fn(node.key, key)
total_prefix_length += prefix_len
key = key[prefix_len:]
value = value[prefix_len:]
# ...부분 매칭 시 분할...
if len(key):
new_node = TreeNode(priority=priority)
new_node.parent = node
new_node.key = key
new_node.value = value.clone()
node.children[child_key] = new_node
self.evictable_size_ += len(key)
return total_prefix_length
반환값 total_prefix_length는 이미 캐시에 존재하던 프리픽스 길이다. 호출자는 이 값을 사용해 중복 KV 인덱스를 해제한다.
Eviction 전략
SGLang은 7가지 eviction 정책을 지원한다.
if self.eviction_policy == "lru":
self.eviction_strategy = LRUStrategy()
elif self.eviction_policy == "lfu":
self.eviction_strategy = LFUStrategy()
elif self.eviction_policy == "slru":
self.eviction_strategy = SLRUStrategy()
# ... fifo, mru, filo, priority
eviction은 리프 노드부터 시작하여 부모로 전파된다.
def evict(self, params: EvictParams) -> EvictResult:
leaves = list(self.evictable_leaves)
eviction_heap = [
(self.eviction_strategy.get_priority(node), node) for node in leaves
]
heapq.heapify(eviction_heap)
num_evicted = 0
while num_evicted < num_tokens and len(eviction_heap):
_priority, x = heapq.heappop(eviction_heap)
self.token_to_kv_pool_allocator.free(x.value)
num_evicted += len(x.value)
self._delete_leaf(x)
# 부모가 새로운 리프가 되면 힙에 추가
if len(x.parent.children) == 0 and x.parent.lock_ref == 0:
heapq.heappush(eviction_heap, (..., x.parent))
참조 카운트: lock_ref
요청이 캐시 노드를 사용할 때, 해당 노드부터 루트까지 전체 경로를 보호한다.
def inc_lock_ref(self, node: TreeNode) -> IncLockRefResult:
delta = 0
while node != self.root_node:
if node.lock_ref == 0:
self.evictable_size_ -= len(node.key)
self.protected_size_ += len(node.key)
delta -= len(node.key)
node.lock_ref += 1
self._update_leaf_status(node)
node = node.parent
return IncLockRefResult(delta=delta)
이 설계 덕분에 한 요청이 사용 중인 프리픽스는 다른 요청의 eviction으로부터 안전하게 보호된다.
설계 근거
RadixAttention의 핵심 설계 선택 세 가지를 정리한다.
Radix Tree를 선택한 이유: Trie와 달리 Radix Tree는 공통 프리픽스를 하나의 노드로 압축한다. LLM 요청의 토큰 시퀀스는 수천 개에 달하므로, 토큰마다 노드를 만드는 Trie는 메모리와 탐색 비용이 비현실적이다. Radix Tree는 분기가 발생하는 지점에서만 노드를 생성하므로 공간 효율이 높다.
동적 노드 분할: _split_node를 통해 필요한 시점에만 노드를 분할한다. 사전에 고정 크기로 쪼개지 않으므로, 실제 공유 패턴에 맞춰 트리가 자연스럽게 진화한다.
다중 eviction 정책: LRU만으로는 부족한 시나리오가 있다. 예를 들어 SLRU는 hit count 기반으로 자주 사용되는 프리픽스를 보호하고, Priority 정책은 사용자가 명시적으로 중요도를 지정할 수 있게 한다.
관련 포스트
- C++ Radix Tree: 고성능 캐시를 위한 네이티브 구현
- GPU Memory Pool: 블록 기반 KV 캐시 메모리 할당
- Allocator: 토큰-KV 풀 할당 전략의 설계
- HiRadixCache: 계층적 GPU/CPU/Disk KV 캐시
- Sliding Window Attention 캐시: SWA 최적화 설계
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Prefill Delayer: 전략적 프리필 지연으로 디코드 처리량 극대화
- 현재글 : [SGLang] RadixAttention: Radix Tree 기반 프리픽스 캐싱의 핵심
- 다음글 [SGLang] C++ Radix Tree: 고성능 캐시를 위한 네이티브 구현
댓글