[SGLang] Staging Buffer: KV 캐시 전송 버퍼 관리
들어가며
Prefill과 Decode 서버의 TP(Tensor Parallel) 크기가 다른 heterogeneous TP 환경에서, KV 캐시를 토큰 단위로 RDMA 전송하면 O(tokens * layers) 개의 소규모 요청이 발생한다. SGLang의 Staging Buffer는 흩어진 head 슬라이스를 연속 GPU 메모리로 모은 뒤 벌크 RDMA 전송을 수행하여, 요청 수를 O(layers) 또는 O(1)로 줄인다.
구조도
Prefill 서버 (TP=4) Decode 서버 (TP=2)
┌─────────────────────┐ ┌─────────────────────┐
│ KV Pool │ │ KV Pool │
│ [head0][head1]... │ │ [head0~1][head2~3] │
│ 분산된 head 슬라이스 │ │ │
└────────┬────────────┘ └──────────▲──────────┘
│ gather │ scatter
▼ │
┌─────────────────────┐ RDMA ┌──────────┴──────────┐
│ Staging Buffer │ ─────────► │ Staging Allocator │
│ (Prefill측, 연속) │ bulk xfer │ (Decode측, ring buf) │
└─────────────────────┘ └─────────────────────┘
핵심 코드 분석
StagingBuffer: GPU 스테이징 메모리
python/sglang/srt/disaggregation/common/staging_buffer.py의 StagingBuffer는 벌크 전송을 위한 연속 GPU 메모리를 할당한다.
class StagingBuffer:
def __init__(self, size_bytes, device, gpu_id, custom_mem_pool=None):
torch.cuda.set_device(gpu_id)
if custom_mem_pool is not None:
with torch.cuda.use_mem_pool(custom_mem_pool):
self.buffer = torch.empty(size_bytes, dtype=torch.uint8, device=device)
else:
self.buffer = torch.empty(size_bytes, dtype=torch.uint8, device=device)
self.data_ptr = self.buffer.data_ptr()
custom_mem_pool이 제공되면 cuMemCreate 기반 메모리 풀에서 할당하여 NVLink/MNNVL 전송과 호환되도록 한다.
StagingAllocator: Ring Buffer 방식 할당
Decode 측에서는 Ring Buffer 방식의 동적 할당기를 사용한다. 오버커밋을 허용하고, watermark로 안전 영역을 추적한다.
class StagingAllocator:
ALLOC_OVERSIZED = -2
def __init__(self, total_size_bytes, device, gpu_id, custom_mem_pool=None):
self.buffer = StagingBuffer(total_size_bytes, device, gpu_id, custom_mem_pool)
self.total_size = total_size_bytes
self.base_ptr = self.buffer.data_ptr
self.head = 0
self.round = 0
self.allocations: dict = {}
self.alloc_order: List[int] = []
할당과 해제
Ring Buffer에서 공간을 할당하고, 해제 시 watermark를 전진시킨다.
def assign(self, required_bytes: int) -> Optional[Tuple[int, int, int]]:
with self.lock:
if required_bytes > self.total_size:
return None
space_at_end = self.total_size - self.head
if required_bytes <= space_at_end:
offset = self.head
self.head += required_bytes
else:
self.round += 1
offset = 0
self.head = required_bytes
alloc_id = self.next_alloc_id
self.next_alloc_id += 1
self.allocations[alloc_id] = (offset, required_bytes, self.round)
return (alloc_id, offset, self.round)
def free(self, alloc_id: int):
with self.lock:
self.allocations.pop(alloc_id)
while self.alloc_order and self.alloc_order[0] not in self.allocations:
self.alloc_order.pop(0)
if not self.allocations:
self.watermark_round = self.round
self.watermark_tail = self.head
Watermark (round, tail_offset) 이전의 영역은 Prefill 서버가 안전하게 쓸 수 있다.
Triton Gather 커널: GPU 내 데이터 모으기
KV 풀에서 흩어진 head 슬라이스를 Staging Buffer로 모으는 Triton 커널이다.
@triton.jit
def _fused_gather_to_staging_kernel(
layer_ptrs, page_indices, staging, num_tokens,
stride_pool_token, head_offset, per_layer_elems,
ELEMS_PER_TOKEN: tl.constexpr, PAGE_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
layer_id = tl.program_id(0)
block_id = tl.program_id(1)
layer_ptr = tl.load(layer_ptrs + layer_id).to(staging.dtype)
offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
t_idx = offsets // ELEMS_PER_TOKEN
e_idx = offsets % ELEMS_PER_TOKEN
page_id = t_idx // PAGE_SIZE
intra_page = t_idx % PAGE_SIZE
page_val = tl.load(page_indices + page_id, mask=mask, other=0)
pool_token = page_val * PAGE_SIZE + intra_page
src_offsets = pool_token * stride_pool_token + head_offset + e_idx
vals = tl.load(layer_ptr + src_offsets, mask=mask)
dst_offsets = tl.program_id(0) * per_layer_elems + offsets
tl.store(staging + dst_offsets, vals, mask=mask)
(layer_id, block_id) 2D 그리드로 모든 레이어의 K/V 버퍼를 단일 커널 호출로 처리한다.
Triton Scatter 커널: 스테이징에서 KV 풀로
Decode 측에서 수신한 스테이징 데이터를 KV 풀의 정확한 위치에 분배한다.
@triton.jit
def _fused_scatter_from_staging_kernel(
layer_ptrs, page_indices, staging, writer_head_offsets,
num_tokens, stride_pool_token, per_layer_elems,
ELEMS_PER_TOKEN: tl.constexpr, PAGE_SIZE: tl.constexpr,
NUM_LAYERS_X2: tl.constexpr, BLOCK_SIZE: tl.constexpr,
):
prog_id = tl.program_id(0)
writer_id = prog_id // NUM_LAYERS_X2
layer_kv_id = prog_id % NUM_LAYERS_X2
head_offset = tl.load(writer_head_offsets + writer_id)
per_rank_elems = per_layer_elems * NUM_LAYERS_X2
src_offsets = writer_id * per_rank_elems + layer_kv_id * per_layer_elems + offsets
vals = tl.load(staging + src_offsets, mask=mask)
dst_offsets = pool_token * stride_pool_token + head_offset + e_idx
tl.store(layer_ptr + dst_offsets, vals, mask=mask)
Head Slice 파라미터 계산
src와 dst의 TP 크기가 다를 때 어떤 head 범위를 전송해야 하는지 계산한다.
def compute_head_slice_params(
src_attn_tp_size, dst_attn_tp_size,
src_tp_rank, dst_tp_rank, total_kv_heads,
) -> Tuple[int, int, int, int]:
src_heads_per_rank = max(1, total_kv_heads // src_attn_tp_size)
dst_heads_per_rank = max(1, total_kv_heads // dst_attn_tp_size)
if src_attn_tp_size > dst_attn_tp_size:
src_head_start = 0
num_heads_to_send = src_heads_per_rank
dst_head_start = (unique_head_idx * src_heads_per_rank) % dst_heads_per_rank
else:
src_head_start = (dst_tp_rank * dst_heads_per_rank) % src_heads_per_rank
num_heads_to_send = dst_heads_per_rank
dst_head_start = 0
return src_head_start, num_heads_to_send, dst_head_start, num_heads_to_send
Dispatch: Torch vs Triton
환경 변수로 Triton 커널과 torch.gather 폴백 중 선택할 수 있다.
_USE_TRITON_STAGING = not bool(os.environ.get("SGLANG_STAGING_USE_TORCH", ""))
def gather_all_layers_to_staging(...):
if _USE_TRITON_STAGING:
return _gather_all_layers_triton(...)
return _gather_all_layers_torch(...)
설계 근거
왜 Staging Buffer가 필요한가?
TP=4 Prefill에서 TP=2 Decode로 전송할 때, 각 Prefill rank는 전체 head의 1/4만 가지고 있다. Staging Buffer 없이는 각 토큰, 각 레이어마다 개별 RDMA 요청이 필요하여 네트워크 오버헤드가 심각해진다. Staging Buffer로 모든 레이어를 연속 메모리에 모은 뒤 단일 벌크 전송으로 처리한다.
Ring Buffer + Overcommit
Decode 측의 StagingAllocator는 할당 시점에는 항상 성공하고(overcommit), 실제 안전성은 watermark를 통해 Prefill 측에서 검증한다. 이를 통해 할당 실패로 인한 지연을 방지한다.
관련 포스트
참고
관련 포스트
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
- [sglang] sglang diffusion 모델 성능 향상: Cache-DiT와 torch.compile의 최적화된 적용 순서
- [sglang] NixlKVManager 성능 향상: 비동기 및 멀티스레드 KV 전송 도입
SGLang 의 다른글
- 이전글 [SGLang] Disaggregation 커넥터: Mooncake, NIXL, MORI 전송 엔진
- 현재글 : [SGLang] Staging Buffer: KV 캐시 전송 버퍼 관리
- 다음글 [SGLang] LoRA Manager: 어댑터 라이프사이클 관리
댓글