본문으로 건너뛰기

[SGLang] Allocator: 토큰-KV 풀 할당 전략의 설계

들어가며

Memory Pool이 GPU 메모리의 물리적 공간을 관리한다면, Allocator는 "어떤 슬롯을 어떤 요청에 할당할 것인가"라는 논리적 할당 전략을 담당한다. SGLang은 두 가지 Allocator를 제공한다. 토큰 단위의 TokenToKVPoolAllocator와 페이지 단위의 PagedTokenToKVPoolAllocator이다.

이 글에서는 python/sglang/srt/mem_cache/allocator.py를 중심으로 Allocator의 설계를 분석한다.

두 가지 Allocator

SGLang의 할당기는 공통 인터페이스를 공유하되, 할당 단위가 다르다.

┌──────────────────────────────────────────────────────────┐
│  TokenToKVPoolAllocator (page_size=1)                   │
│  ┌─┬─┬─┬─┬─┬─┬─┬─┬─┬─┐                                │
│  │123456789│…│  토큰 단위 할당                   │
│  └─┴─┴─┴─┴─┴─┴─┴─┴─┴─┘                                │
├──────────────────────────────────────────────────────────┤
│  PagedTokenToKVPoolAllocator (page_size=16)             │
│  ┌────────────┬────────────┬────────────┐               │
│  │  Page 1Page 2Page 3    │               │
│  │ [0..15][16..31][32..47]   │ 페이지 단위   │
│  └────────────┴────────────┴────────────┘               │
└──────────────────────────────────────────────────────────┘

BaseTokenToKVPoolAllocator: 공통 인터페이스

두 Allocator의 공통 기반 클래스이다.

class BaseTokenToKVPoolAllocator(abc.ABC):
    def __init__(self, size, page_size, dtype, device, kvcache, need_sort):
        self.size = size
        self.page_size = page_size
        self.device = device
        self._kvcache = kvcache
        self.need_sort = need_sort
        self.free_pages = None
        self.release_pages = None
        self.is_not_in_free_group = True
        self.free_group = []

    def available_size(self):
        return (len(self.free_pages) + len(self.release_pages)) * self.page_size

free_pagesrelease_pages를 분리한 것이 핵심이다. need_sort=True일 때, 해제된 페이지는 즉시 free_pages에 합치지 않고 release_pages에 모아두었다가 필요할 때 정렬 후 병합한다. 이는 연속 할당이 필요한 시나리오에서 단편화를 방지한다.

def merge_and_sort_free(self):
    if len(self.release_pages) > 0:
        self.free_pages = torch.cat((self.free_pages, self.release_pages))
        self.free_pages, _ = torch.sort(self.free_pages)
        self.release_pages = torch.empty(
            (0,), dtype=self.release_pages.dtype, device=self.device)

TokenToKVPoolAllocator: 토큰 단위 할당

가장 단순한 할당기이다. page_size=1로 동작한다.

class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    def __init__(self, size, dtype, device, kvcache, need_sort):
        super().__init__(size, 1, dtype, device, kvcache, need_sort)
        self.clear()

    def clear(self):
        # 슬롯 0은 패딩용 더미 출력에 사용
        self.free_pages = torch.arange(
            1, self.size + 1, dtype=torch.int64, device=self.device)
        self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)

    def alloc(self, need_size: int):
        if self.need_sort and need_size > len(self.free_pages):
            self.merge_and_sort_free()
        if need_size > len(self.free_pages):
            return None
        select_index = self.free_pages[:need_size]
        self.free_pages = self.free_pages[need_size:]
        return select_index

슬롯 0을 예약하는 것이 중요하다. Attention 연산에서 패딩 토큰의 출력을 슬롯 0에 쓰면, 실제 데이터가 손상되지 않는다.

free 메서드는 free_group 메커니즘을 지원한다.

def free(self, free_index: torch.Tensor):
    if free_index.numel() == 0:
        return
    if self.is_not_in_free_group:
        if self.need_sort:
            self.release_pages = torch.cat((self.release_pages, free_index))
        else:
            self.free_pages = torch.cat((self.free_pages, free_index))
    else:
        self.free_group.append(free_index)

free_group은 배치 처리를 위한 최적화이다. 여러 요청의 해제를 모아서 한 번에 처리하면 torch.cat 호출 횟수를 줄인다.

def free_group_begin(self):
    self.is_not_in_free_group = False
    self.free_group = []

def free_group_end(self):
    self.is_not_in_free_group = True
    if self.free_group:
        self.free(torch.cat(self.free_group))

PagedTokenToKVPoolAllocator: 페이지 단위 할당

page_size > 1일 때 사용하는 할당기이다. Triton 커널로 prefill과 decode의 인덱스 계산을 GPU에서 직접 수행한다.

class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    def __init__(self, size, page_size, dtype, device, kvcache, need_sort):
        super().__init__(size, page_size, dtype, device, kvcache, need_sort)
        self.num_pages = size // page_size

alloc_extend: Prefill 할당

Prefill 단계에서 여러 요청의 확장 토큰을 한 번에 할당한다.

def alloc_extend(self, prefix_lens, prefix_lens_cpu, seq_lens,
                 seq_lens_cpu, last_loc, extend_num_tokens):
    bs = len(prefix_lens)
    out_indices = torch.empty(
        (extend_num_tokens,), dtype=torch.int64, device=self.device)

    alloc_extend_kernel[(bs,)](
        prefix_lens, seq_lens, last_loc, self.free_pages,
        out_indices, next_power_of_2(bs), self.page_size,
    )
    num_new_pages = get_num_new_pages(
        seq_lens=seq_lens_cpu, page_size=self.page_size,
        prefix_lens=prefix_lens_cpu)
    self.free_pages = self.free_pages[num_new_pages:]
    return out_indices

Triton 커널 alloc_extend_kernel은 세 부분으로 나누어 인덱스를 계산한다.

┌─────────────┬────────────────────────┬──────────────┐
│   Part 1    │       Part 2           │    Part 3    │
│ 기존 부분   │    새 전체 페이지       │  새 부분     │
│ 페이지 나머지│    (full pages)        │  페이지 시작 │
└─────────────┴────────────────────────┴──────────────┘
  prefix 잔여    새로 할당된 페이지들     마지막 부분 페이지

Part 1은 이전 페이지의 남은 공간을 채우고, Part 2는 새 전체 페이지를 할당하며, Part 3은 마지막 부분 페이지를 처리한다.

alloc_decode: Decode 할당

Decode 단계에서는 각 요청에 1토큰씩만 할당하면 된다.

def alloc_decode(self, seq_lens, seq_lens_cpu, last_loc):
    bs = len(seq_lens)
    out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
    alloc_decode_kernel[(bs,)](
        seq_lens, last_loc, self.free_pages, out_indices,
        next_power_of_2(bs), self.page_size,
    )

Triton 커널은 새 페이지가 필요한지 판단한다. 현재 페이지에 공간이 남아있으면 last_loc + 1을 사용하고, 페이지 경계를 넘으면 새 페이지의 첫 번째 슬롯을 할당한다.

# Triton 커널의 핵심 로직
if num_page_start_loc_self == 0:
    last_loc = tl.load(last_loc_ptr + pid)
    tl.store(out_indices + pid, last_loc + 1)
else:
    page = tl.load(free_page_ptr + new_page_start_loc)
    tl.store(out_indices + pid, page * page_size)

free: 페이지 단위 해제

페이지 단위 할당기의 free는 인덱스를 페이지 번호로 변환한 후 중복을 제거한다.

def free(self, free_index: torch.Tensor):
    if free_index.numel() == 0:
        return
    if self.is_not_in_free_group:
        free_page_indices = torch.unique(free_index // self.page_size)
        if self.need_sort:
            self.release_pages = torch.cat((free_page_indices, self.release_pages))
        else:
            self.free_pages = torch.cat((free_page_indices, self.free_pages))
    else:
        self.free_group.append(free_index)

torch.unique로 같은 페이지에 속하는 여러 토큰 인덱스를 하나의 페이지 번호로 합치는 것이 핵심이다.

설계 근거

Triton 커널을 사용하는 이유: Prefill 단계에서 수백 개 요청의 인덱스를 동시에 계산해야 한다. Python 루프로 처리하면 요청 수에 비례하는 CPU 오버헤드가 발생한다. Triton 커널은 각 요청을 GPU 블록 하나에 매핑하여 병렬 처리한다.

free_pages/release_pages 분리의 이유: 해제된 인덱스를 즉시 free_pages에 합치면, 매 해제마다 정렬이 필요해진다. release_pages에 모아두었다가 할당이 부족할 때만 정렬/병합하면, 불필요한 정렬 연산을 줄인다.

page_size의 트레이드오프: page_size가 크면 할당/해제 오버헤드가 줄지만 내부 단편화(사용하지 않는 슬롯)가 증가한다. page_size=1은 단편화가 없지만 관리 비용이 크다. SGLang은 이를 사용자가 선택할 수 있게 두 Allocator를 모두 제공한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글