[SGLang] C++ Radix Tree: 고성능 캐시를 위한 네이티브 구현
들어가며
Python으로 구현된 RadixCache는 기능적으로 완전하지만, 수십만 개의 노드를 관리할 때 GIL과 Python 객체 오버헤드가 병목이 된다. SGLang은 이 문제를 해결하기 위해 C++20으로 Radix Tree V2를 구현하고, PyTorch의 cpp_extension을 통해 Python에 바인딩했다.
이 글에서는 python/sglang/srt/mem_cache/cpp_radix_tree/ 디렉토리의 코드를 중심으로 C++ 구현의 설계를 분석한다.
전체 구조
C++ Radix Tree는 다음과 같은 파일 구조로 구성된다.
cpp_radix_tree/
├── common.h # 기본 타입 정의 (token_t, NodeHandle 등)
├── tree_v2_node.h # TreeNode 구조체 (노드 데이터 + 연산)
├── tree_v2_impl.h # RadixTree::Impl (트리 핵심 알고리즘)
├── tree_v2.h # RadixTree 공개 인터페이스
├── tree_v2.cpp # RadixTree 메서드 구현
├── tree_v2_binding.cpp # pybind11 바인딩
├── tree_v2_debug.cpp # 디버그 출력
└── radix_tree.py # Python 타입 힌트 + JIT 컴파일
Python 바인딩: JIT 컴파일
Python에서 C++ 모듈을 로드하는 방식이 특이하다. torch.utils.cpp_extension.load를 사용하여 런타임에 JIT 컴파일한다.
from torch.utils.cpp_extension import load
_abs_path = os.path.dirname(os.path.abspath(__file__))
radix_tree_cpp = load(
name="radix_tree_cpp",
sources=[
f"{_abs_path}/tree_v2_binding.cpp",
f"{_abs_path}/tree_v2_debug.cpp",
f"{_abs_path}/tree_v2.cpp",
],
extra_cflags=["-O3", "-std=c++20"],
)
-O3최적화와 C++20 표준을 사용한다. C++20의 std::ranges, std::span, std::source_location 등 모던 기능을 적극 활용하기 위한 선택이다.
기본 타입 정의: common.h
C++ 코드 전체에서 사용하는 기본 타입들이 정의되어 있다.
namespace radix_tree_v2 {
using token_t = std::int32_t;
using token_vec_t = std::vector<token_t>;
using token_slice = std::span<const token_t>;
using NodeHandle = std::size_t;
using IOTicket = std::uint32_t;
}
token_slice는 C++20의 std::span을 사용한다. 복사 없이 토큰 시퀀스의 일부를 참조할 수 있어 성능에 유리하다. NodeHandle은 노드를 식별하는 정수 ID로, Python 바인딩에서 포인터 대신 사용하여 메모리 안전성을 확보한다.
TreeNode: 노드 데이터 구조
C++ TreeNode은 Python 버전보다 훨씬 정교한 상태 관리를 지원한다.
struct TreeNode {
using childern_map_t = std::unordered_map<
token_vec_t, std::unique_ptr<TreeNode>, std_vector_hash>;
TreeNode(std::size_t node_id_)
: ref_count(0), hit_count(0),
m_io_locked(std::nullopt),
m_io_status(IOStatus::None),
m_tokens(), m_device_indices(), m_host_indices(),
m_parent(), m_children(),
m_last_access_time(std::chrono::steady_clock::now()),
node_id(node_id_) {}
Python 버전과 비교한 핵심 차이점 세 가지가 있다.
첫째, 소유권 관리. 자식 노드를 std::unique_ptr로 관리한다. 노드가 삭제될 때 하위 트리가 자동으로 정리된다.
둘째, IO 상태 추적. IOStatus enum으로 GPU/CPU 간 데이터 전송 상태를 추적한다.
enum class IOStatus : std::uint8_t {
None, // IO 작업 없음
HostToDevice, // CPU→GPU 로딩 중
DeviceToHost, // GPU→CPU 백업 중
};
셋째, 이중 인덱스. m_device_indices(GPU)와 m_host_indices(CPU)를 동시에 관리하여 HiCache의 계층적 캐싱을 네이티브로 지원한다.
bool on_gpu() const { return m_device_indices.defined(); }
bool on_cpu() const { return m_host_indices.defined(); }
bool on_both() const { return on_gpu() && on_cpu(); }
노드 분할: split_prefix
노드 분할은 friend 함수로 구현되어 private 멤버에 직접 접근한다.
friend void split_prefix(TreeNode* new_node, TreeNode* old_node,
std::size_t prefix_length) {
auto tokens = std::move(old_node->m_tokens);
// 토큰 분할
old_node->m_tokens = token_vec_t(
tokens.begin() + prefix_length, tokens.end());
new_node->m_tokens = std::move(tokens);
new_node->m_tokens.resize(prefix_length);
// GPU/CPU 인덱스 텐서 분할
if (old_node->m_device_indices.defined()) {
auto new_indices = old_node->m_device_indices
.split_with_sizes({new_size, old_size});
new_node->m_device_indices = std::move(new_indices[0]);
old_node->m_device_indices = std::move(new_indices[1]);
}
// host_indices도 동일하게 분할
new_node->ref_count = old_node->ref_count;
new_node->hit_count = old_node->hit_count;
}
Python 버전과 달리 std::move를 사용하여 토큰 벡터의 불필요한 복사를 방지하고, at::Tensor::split_with_sizes로 텐서를 효율적으로 분할한다.
트리 탐색: tree_walk
Impl::tree_walk은 키를 따라 트리를 내려가며 매칭하는 핵심 알고리즘이다.
std::pair<TreeNode*, std::size_t> tree_walk(token_slice key) {
_assert(key.size() % page_size == 0, "Key should be page-aligned");
std::size_t total_prefix_length = 0;
TreeNode* node = &m_root;
const auto now = std::chrono::steady_clock::now();
while (key.size() > 0) {
const auto iterator = node->find_child(get_key(key));
if (iterator == node->end()) break;
node = iterator->second.get();
const auto prefix_length =
align(node->diff_key(key, page_size) + page_size);
total_prefix_length += prefix_length;
if (prefix_length < node->length()) {
return {split_node(iterator, prefix_length),
total_prefix_length};
}
node->access(now);
key = key.subspan(prefix_length);
}
return {node, total_prefix_length};
}
diff_key 메서드는 C++20의 std::ranges::mismatch를 사용하여 두 토큰 시퀀스의 첫 번째 차이점을 효율적으로 찾는다.
std::size_t diff_key(token_slice key, std::size_t offset) const {
const auto a = token_slice{key}.subspan(offset);
const auto b = token_slice{m_tokens}.subspan(offset);
const auto [it_a, it_b] = std::ranges::mismatch(a, b);
return it_a - a.begin();
}
Eviction
C++ eviction은 std::priority_queue를 사용한 min-heap 기반 LRU이다.
std::vector<at::Tensor> RadixTree::evict(std::size_t num_tokens) {
auto heap = std::priority_queue{
cmp, m_impl->collect_leaves_device()};
std::vector<at::Tensor> evicted_values;
std::size_t num_evict = 0;
while (num_evict < num_tokens && !heap.empty()) {
const auto node = heap.top();
heap.pop();
if (!node->is_io_free()) continue; // IO 중인 노드 건너뛰기
evicted_values.push_back(node->device_indices());
num_evict += node->length();
const auto parent = node->parent();
m_impl->remove_device_node(node);
if (parent->is_leaf_device() && parent->ref_count == 0)
heap.push(parent);
}
return evicted_values;
}
IO 중인 노드(!node->is_io_free())는 건너뛰는 것이 Python 버전에 없는 안전장치다. HiCache에서 GPU-CPU 간 데이터 전송 중인 노드를 실수로 evict하는 것을 방지한다.
Python 인터페이스
Python 측에서는 TYPE_CHECKING 가드를 통해 타입 힌트를 제공하면서, 런타임에는 C++ 객체를 직접 사용한다.
if TYPE_CHECKING:
class RadixTreeCpp:
def match_prefix(self, prefix: List[int]
) -> Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
...
def evict(self, num_tokens: int) -> List[torch.Tensor]:
...
def lock_ref(self, handle: TreeNodeCpp, lock: bool) -> None:
...
def writing_through(self, key, indices
) -> Tuple[List[Tuple[IOHandle, Tensor, Tensor]], int]:
...
else:
RadixTreeCpp = radix_tree_cpp.RadixTree
이 설계는 IDE 자동완성과 타입 검사를 지원하면서도 런타임 오버헤드가 전혀 없다.
설계 근거
C++20을 선택한 이유: std::span, std::ranges, std::source_location 등 모던 C++ 기능이 코드 안전성과 성능을 동시에 제공한다. std::span은 토큰 시퀀스의 부분 참조를 복사 없이 가능하게 하고, std::ranges::mismatch는 SIMD 최적화 가능한 비교 연산을 제공한다.
NodeHandle을 사용하는 이유: Python에서 C++ 포인터를 직접 다루면 메모리 안전 문제가 발생한다. 정수 ID(NodeHandle)로 노드를 식별하고, m_node_map에서 실제 포인터로 변환하는 간접 참조 방식은 안전하면서도 O(1) 접근을 보장한다.
JIT 컴파일을 사용하는 이유: 사전 빌드(wheel 패키징) 대신 JIT 컴파일을 선택하면 다양한 컴파일러와 시스템 환경에 자동으로 적응한다. 초기 로딩은 느리지만, 캐시되어 이후 로딩은 빠르다.
관련 포스트
- RadixAttention: Radix Tree 기반 프리픽스 캐싱의 핵심
- GPU Memory Pool: 블록 기반 KV 캐시 메모리 할당
- HiRadixCache: 계층적 GPU/CPU/Disk KV 캐시
참고
관련 포스트
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [transformers] Hugging Face Transformers: SequenceFeatureExtractor.pad() 최적화로 불필요한 NumPy 배열 재변환 제거
- [uv] uv의 로컬 휠(Wheel) 압축 해제 성능 회귀 문제 해결: astral_async_zip 버전 업데이트
- [cpython] tarfile 스트리밍 모드(r|*) 성능 개선: 파이썬 압축 파일 처리의 숨겨진 병목 제거
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
SGLang 의 다른글
- 이전글 [SGLang] RadixAttention: Radix Tree 기반 프리픽스 캐싱의 핵심
- 현재글 : [SGLang] C++ Radix Tree: 고성능 캐시를 위한 네이티브 구현
- 다음글 [SGLang] GPU Memory Pool: 블록 기반 KV 캐시 메모리 할당
댓글