본문으로 건너뛰기

[vLLM] Mamba (SSM): 선형 시간 복잡도 시퀀스 모델링

들어가며

Mamba는 Structured State Space Model(SSM) 기반의 시퀀스 모델로, Transformer의 O(n^2) 어텐션을 O(n) 선형 재귀로 대체한다. vLLM은 Mamba를 attention backend로 추상화하여 기존 PagedAttention 프레임워크 안에서 SSM 상태를 관리한다. 이 글에서는 mamba_attn.py의 메타데이터 구조와 상태 관리 전략을 분석한다.

소스 경로: vllm/v1/attention/backends/mamba_attn.py

논문: Mamba: Linear-Time Sequence Modeling with Selective State Spaces

핵심 구조/코드 분석

BaseMambaAttentionMetadata

@dataclass
class BaseMambaAttentionMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    num_reqs: int

    # Prefill 전용 텐서
    has_initial_states_p: torch.Tensor | None
    query_start_loc_p: torch.Tensor | None
    num_computed_tokens_p: torch.Tensor | None
    state_indices_tensor_p: torch.Tensor | None

    # Decode 전용 텐서
    state_indices_tensor_d: torch.Tensor | None
    query_start_loc_d: torch.Tensor | None

    # 투기적 디코딩 지원
    num_accepted_tokens: torch.Tensor | None

    # 프리픽스 캐싱 지원
    block_idx_last_scheduled_token: torch.Tensor | None
    block_idx_first_scheduled_token_p: torch.Tensor | None
    block_idx_last_computed_token: torch.Tensor | None

Transformer의 KV 캐시와 달리 Mamba는 고정 크기 상태(state)를 유지한다. state_indices_tensor가 각 요청의 상태가 저장된 블록 인덱스를 가리키며, prefill과 decode에 대해 별도로 관리된다.

MetadataBuilder 초기화

class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
    _cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

    def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
        self.num_spec_tokens = vllm_config.num_speculative_tokens
        self.use_spec_decode = self.num_spec_tokens > 0

        if vllm_config.cache_config.mamba_cache_mode == "all":
            self.state_indices_tensor_d = torch.empty(
                (self.decode_cudagraph_max_bs, max_num_blocks),
                dtype=torch.int32, device=device,
            )
        else:
            self.state_indices_tensor_d = torch.empty(
                (self.decode_cudagraph_max_bs, 1 + self.num_spec_tokens),
                dtype=torch.int32, device=device,
            )

mamba_cache_mode에 따라 상태 인덱스 텐서의 크기가 달라진다:

  • "all" 모드: 모든 블록에 대한 상태를 추적 (프리픽스 캐싱 지원)
  • 기본 모드: 현재 상태 + 투기적 디코딩 토큰 수만큼만 추적

청크 메타데이터 계산

def _compute_chunk_metadata(self, chunk_size, num_prefills,
                            num_computed_tokens_p_cpu, query_start_loc_p_cpu):
    """청크별 메타데이터 생성. 두 가지 제약 조건 보장:
    1. 각 청크는 단일 시퀀스의 토큰만 포함
    2. chunk_size 토큰마다 mamba 상태를 확실히 추출 가능
    """
    cu_chunk_seqlen = []
    for req_idx in range(num_prefills):
        this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
        this_new_tokens = (query_start_loc_p_cpu[req_idx + 1].item()
                          - query_start_loc_p_cpu[req_idx].item())

        # 이전 청크가 정렬되지 않은 경우 먼저 마무리
        if this_num_computed % chunk_size != 0:
            chunk_len = (cdiv(this_num_computed, chunk_size) * chunk_size
                        - this_num_computed)
            chunk_len = min(chunk_len, this_new_tokens)
            seqlen_pos += chunk_len
            this_new_tokens -= chunk_len

Mamba prefill에서는 시퀀스를 chunk_size 단위로 분할하여 중간 상태를 저장한다. 이는 프리픽스 캐싱의 핵심이다 -- 청크 경계에서의 Mamba 상태를 캐싱하면, 이후 같은 프리픽스를 재계산하지 않아도 된다.

CUDA Graph 지원

def _update_metadata_for_cudagraph_capture(self, metadata):
    if (metadata.num_prefills == 0
        and metadata.num_decodes <= self.decode_cudagraph_max_bs
        and self.compilation_config.cudagraph_mode.has_full_cudagraphs()):

        padded_bs = metadata.num_reqs
        self.state_indices_tensor_d[:metadata.num_decodes].copy_(
            state_indices_tensor_d, non_blocking=True)
        state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
        state_indices_tensor_d[metadata.num_decodes:] = NULL_BLOCK_ID

decode-only 배치에서 CUDA 그래프를 지원한다. 배치 크기를 패딩하고 남는 슬롯에 NULL_BLOCK_ID를 채워 정적 크기의 그래프를 재사용할 수 있게 한다.

왜 이 설계인가

  1. Attention Backend 추상화: Mamba의 상태를 KV 캐시 프레임워크의 "블록"으로 매핑한다. 이를 통해 스케줄러와 메모리 관리 코드를 Transformer와 공유하면서도 SSM의 고유한 상태 관리를 지원한다.

  2. Prefill/Decode 분리: SSM에서 prefill은 병렬 스캔이고 decode는 순차적 갱신이다. 메타데이터를 분리하여 각각에 최적화된 커널을 사용한다.

  3. 청크 기반 프리픽스 캐싱: Transformer에서는 토큰별로 KV를 캐싱하지만, Mamba에서는 chunk_size 간격으로 상태를 저장한다. 단일 시퀀스만 포함하는 청크 제약은 Mamba 커널을 크게 단순화한다.

  4. 투기적 디코딩 호환: num_accepted_tokens 텐서로 각 요청의 수락된 토큰 수를 추적하여, 올바른 상태 체크포인트에서 복원할 수 있게 한다.

논문 핵심 내용

Mamba 논문의 핵심 기여는 Selective State Space Model(S6)을 도입하여, SSM의 입력 의존적(input-dependent) 선택 메커니즘을 효율적으로 구현한 것이다. 기존 S4 모델은 시간 불변(time-invariant) 파라미터를 사용했지만, Mamba의 S6는 입력에 따라 파라미터(Delta, B, C)를 동적으로 조절한다. 이 선택성 덕분에 관련 있는 정보만 상태에 남기고 불필요한 정보는 필터링할 수 있다.

Transformer 대비 언어 모델링 성능 (Zero-shot)

모델 크기 Pile PPL LAMBADA Acc HellaSwag PIQA 평균
Pythia 160M 29.64 33.0% 30.2% 61.4% 40.6%
Mamba 130M 10.56 44.3% 35.3% 64.5% 44.7%
Pythia 1B 7.82 56.1% 47.2% 70.7% 51.9%
Mamba 790M 7.33 62.7% 55.1% 72.1% 57.1%
Mamba 2.8B 6.22 69.2% 66.1% 75.2% 63.3%

Mamba-2.8B는 동일 크기의 Transformer를 크게 앞서고, 2배 크기의 Transformer와 동등한 성능을 보인다.

추론 처리량

Mamba는 추론 시 Transformer 대비 4-5배 높은 처리량을 달성한다. 시퀀스 길이 2K 이상에서는 FlashAttention-2보다도 빠르다. 선형 시간 복잡도 덕분에 시퀀스 길이가 길어질수록 이 격차는 더 벌어진다.

Selective SSM Ablation

Delta 선택적 B 선택적 C 선택적 Perplexity
X X X 10.93
O X X 9.81
O O O 8.71

Delta, B, C를 모두 입력 의존적으로 만들었을 때 perplexity가 10.93에서 8.71로 대폭 개선된다. 특히 상태 차원을 1에서 16으로 늘리면 파라미터가 1%만 증가하는데 perplexity는 1.0 이상 개선되어, 상태 차원 확장의 효율이 매우 높다는 것을 확인할 수 있다.

다른 도메인 결과

오디오(Speech Generation SC09) 태스크에서 Mamba-6.1M은 FID 0.94, IS 6.26을 달성하여 기존 SaShiMi(FID 1.99, IS 5.13)를 크게 앞섰다. DNA 시퀀스 모델링에서도 동일 파라미터 수 대비 3-4배 적은 파라미터로 Transformer++와 HyenaDNA에 필적하는 성능을 보였다.

참고

댓글

관련 포스트

vLLM 의 다른글