본문으로 건너뛰기

[SGLang] Staging Buffer: KV 캐시 전송 버퍼 관리

들어가며

Prefill과 Decode 서버의 TP(Tensor Parallel) 크기가 다른 heterogeneous TP 환경에서, KV 캐시를 토큰 단위로 RDMA 전송하면 O(tokens * layers) 개의 소규모 요청이 발생한다. SGLang의 Staging Buffer는 흩어진 head 슬라이스를 연속 GPU 메모리로 모은 뒤 벌크 RDMA 전송을 수행하여, 요청 수를 O(layers) 또는 O(1)로 줄인다.

구조도

Prefill 서버 (TP=4)                    Decode 서버 (TP=2)
┌─────────────────────┐               ┌─────────────────────┐
│ KV Pool             │               │ KV Pool             │
│ [head0][head1]...   │               │ [head0~1][head2~3]  │
│  분산된 head 슬라이스 │               │                     │
└────────┬────────────┘               └──────────▲──────────┘
         │ gather                               │ scatter
         ▼                                      │
┌─────────────────────┐    RDMA     ┌──────────┴──────────┐
│ Staging Buffer      │ ─────────► │ Staging Allocator    │
│ (Prefill측, 연속)    │  bulk xfer │ (Decode측, ring buf) │
└─────────────────────┘            └─────────────────────┘

핵심 코드 분석

StagingBuffer: GPU 스테이징 메모리

python/sglang/srt/disaggregation/common/staging_buffer.pyStagingBuffer는 벌크 전송을 위한 연속 GPU 메모리를 할당한다.

class StagingBuffer:
    def __init__(self, size_bytes, device, gpu_id, custom_mem_pool=None):
        torch.cuda.set_device(gpu_id)
        if custom_mem_pool is not None:
            with torch.cuda.use_mem_pool(custom_mem_pool):
                self.buffer = torch.empty(size_bytes, dtype=torch.uint8, device=device)
        else:
            self.buffer = torch.empty(size_bytes, dtype=torch.uint8, device=device)
        self.data_ptr = self.buffer.data_ptr()

custom_mem_pool이 제공되면 cuMemCreate 기반 메모리 풀에서 할당하여 NVLink/MNNVL 전송과 호환되도록 한다.

StagingAllocator: Ring Buffer 방식 할당

Decode 측에서는 Ring Buffer 방식의 동적 할당기를 사용한다. 오버커밋을 허용하고, watermark로 안전 영역을 추적한다.

class StagingAllocator:
    ALLOC_OVERSIZED = -2

    def __init__(self, total_size_bytes, device, gpu_id, custom_mem_pool=None):
        self.buffer = StagingBuffer(total_size_bytes, device, gpu_id, custom_mem_pool)
        self.total_size = total_size_bytes
        self.base_ptr = self.buffer.data_ptr
        self.head = 0
        self.round = 0
        self.allocations: dict = {}
        self.alloc_order: List[int] = []

할당과 해제

Ring Buffer에서 공간을 할당하고, 해제 시 watermark를 전진시킨다.

def assign(self, required_bytes: int) -> Optional[Tuple[int, int, int]]:
    with self.lock:
        if required_bytes > self.total_size:
            return None
        space_at_end = self.total_size - self.head
        if required_bytes <= space_at_end:
            offset = self.head
            self.head += required_bytes
        else:
            self.round += 1
            offset = 0
            self.head = required_bytes
        alloc_id = self.next_alloc_id
        self.next_alloc_id += 1
        self.allocations[alloc_id] = (offset, required_bytes, self.round)
        return (alloc_id, offset, self.round)

def free(self, alloc_id: int):
    with self.lock:
        self.allocations.pop(alloc_id)
        while self.alloc_order and self.alloc_order[0] not in self.allocations:
            self.alloc_order.pop(0)
        if not self.allocations:
            self.watermark_round = self.round
            self.watermark_tail = self.head

Watermark (round, tail_offset) 이전의 영역은 Prefill 서버가 안전하게 쓸 수 있다.

Triton Gather 커널: GPU 내 데이터 모으기

KV 풀에서 흩어진 head 슬라이스를 Staging Buffer로 모으는 Triton 커널이다.

@triton.jit
def _fused_gather_to_staging_kernel(
    layer_ptrs, page_indices, staging, num_tokens,
    stride_pool_token, head_offset, per_layer_elems,
    ELEMS_PER_TOKEN: tl.constexpr, PAGE_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    layer_id = tl.program_id(0)
    block_id = tl.program_id(1)
    layer_ptr = tl.load(layer_ptrs + layer_id).to(staging.dtype)

    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    t_idx = offsets // ELEMS_PER_TOKEN
    e_idx = offsets % ELEMS_PER_TOKEN

    page_id = t_idx // PAGE_SIZE
    intra_page = t_idx % PAGE_SIZE
    page_val = tl.load(page_indices + page_id, mask=mask, other=0)
    pool_token = page_val * PAGE_SIZE + intra_page

    src_offsets = pool_token * stride_pool_token + head_offset + e_idx
    vals = tl.load(layer_ptr + src_offsets, mask=mask)
    dst_offsets = tl.program_id(0) * per_layer_elems + offsets
    tl.store(staging + dst_offsets, vals, mask=mask)

(layer_id, block_id) 2D 그리드로 모든 레이어의 K/V 버퍼를 단일 커널 호출로 처리한다.

Triton Scatter 커널: 스테이징에서 KV 풀로

Decode 측에서 수신한 스테이징 데이터를 KV 풀의 정확한 위치에 분배한다.

@triton.jit
def _fused_scatter_from_staging_kernel(
    layer_ptrs, page_indices, staging, writer_head_offsets,
    num_tokens, stride_pool_token, per_layer_elems,
    ELEMS_PER_TOKEN: tl.constexpr, PAGE_SIZE: tl.constexpr,
    NUM_LAYERS_X2: tl.constexpr, BLOCK_SIZE: tl.constexpr,
):
    prog_id = tl.program_id(0)
    writer_id = prog_id // NUM_LAYERS_X2
    layer_kv_id = prog_id % NUM_LAYERS_X2
    head_offset = tl.load(writer_head_offsets + writer_id)

    per_rank_elems = per_layer_elems * NUM_LAYERS_X2
    src_offsets = writer_id * per_rank_elems + layer_kv_id * per_layer_elems + offsets
    vals = tl.load(staging + src_offsets, mask=mask)
    dst_offsets = pool_token * stride_pool_token + head_offset + e_idx
    tl.store(layer_ptr + dst_offsets, vals, mask=mask)

Head Slice 파라미터 계산

src와 dst의 TP 크기가 다를 때 어떤 head 범위를 전송해야 하는지 계산한다.

def compute_head_slice_params(
    src_attn_tp_size, dst_attn_tp_size,
    src_tp_rank, dst_tp_rank, total_kv_heads,
) -> Tuple[int, int, int, int]:
    src_heads_per_rank = max(1, total_kv_heads // src_attn_tp_size)
    dst_heads_per_rank = max(1, total_kv_heads // dst_attn_tp_size)

    if src_attn_tp_size > dst_attn_tp_size:
        src_head_start = 0
        num_heads_to_send = src_heads_per_rank
        dst_head_start = (unique_head_idx * src_heads_per_rank) % dst_heads_per_rank
    else:
        src_head_start = (dst_tp_rank * dst_heads_per_rank) % src_heads_per_rank
        num_heads_to_send = dst_heads_per_rank
        dst_head_start = 0
    return src_head_start, num_heads_to_send, dst_head_start, num_heads_to_send

Dispatch: Torch vs Triton

환경 변수로 Triton 커널과 torch.gather 폴백 중 선택할 수 있다.

_USE_TRITON_STAGING = not bool(os.environ.get("SGLANG_STAGING_USE_TORCH", ""))

def gather_all_layers_to_staging(...):
    if _USE_TRITON_STAGING:
        return _gather_all_layers_triton(...)
    return _gather_all_layers_torch(...)

설계 근거

왜 Staging Buffer가 필요한가?

TP=4 Prefill에서 TP=2 Decode로 전송할 때, 각 Prefill rank는 전체 head의 1/4만 가지고 있다. Staging Buffer 없이는 각 토큰, 각 레이어마다 개별 RDMA 요청이 필요하여 네트워크 오버헤드가 심각해진다. Staging Buffer로 모든 레이어를 연속 메모리에 모은 뒤 단일 벌크 전송으로 처리한다.

Ring Buffer + Overcommit

Decode 측의 StagingAllocator는 할당 시점에는 항상 성공하고(overcommit), 실제 안전성은 watermark를 통해 Prefill 측에서 검증한다. 이를 통해 할당 실패로 인한 지연을 방지한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글