[vLLM] Mamba (SSM): 선형 시간 복잡도 시퀀스 모델링
들어가며
Mamba는 Structured State Space Model(SSM) 기반의 시퀀스 모델로, Transformer의 O(n^2) 어텐션을 O(n) 선형 재귀로 대체한다. vLLM은 Mamba를 attention backend로 추상화하여 기존 PagedAttention 프레임워크 안에서 SSM 상태를 관리한다. 이 글에서는 mamba_attn.py의 메타데이터 구조와 상태 관리 전략을 분석한다.
소스 경로: vllm/v1/attention/backends/mamba_attn.py
논문: Mamba: Linear-Time Sequence Modeling with Selective State Spaces
핵심 구조/코드 분석
BaseMambaAttentionMetadata
@dataclass
class BaseMambaAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_reqs: int
# Prefill 전용 텐서
has_initial_states_p: torch.Tensor | None
query_start_loc_p: torch.Tensor | None
num_computed_tokens_p: torch.Tensor | None
state_indices_tensor_p: torch.Tensor | None
# Decode 전용 텐서
state_indices_tensor_d: torch.Tensor | None
query_start_loc_d: torch.Tensor | None
# 투기적 디코딩 지원
num_accepted_tokens: torch.Tensor | None
# 프리픽스 캐싱 지원
block_idx_last_scheduled_token: torch.Tensor | None
block_idx_first_scheduled_token_p: torch.Tensor | None
block_idx_last_computed_token: torch.Tensor | None
Transformer의 KV 캐시와 달리 Mamba는 고정 크기 상태(state)를 유지한다. state_indices_tensor가 각 요청의 상태가 저장된 블록 인덱스를 가리키며, prefill과 decode에 대해 별도로 관리된다.
MetadataBuilder 초기화
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
_cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
self.num_spec_tokens = vllm_config.num_speculative_tokens
self.use_spec_decode = self.num_spec_tokens > 0
if vllm_config.cache_config.mamba_cache_mode == "all":
self.state_indices_tensor_d = torch.empty(
(self.decode_cudagraph_max_bs, max_num_blocks),
dtype=torch.int32, device=device,
)
else:
self.state_indices_tensor_d = torch.empty(
(self.decode_cudagraph_max_bs, 1 + self.num_spec_tokens),
dtype=torch.int32, device=device,
)
mamba_cache_mode에 따라 상태 인덱스 텐서의 크기가 달라진다:
- "all" 모드: 모든 블록에 대한 상태를 추적 (프리픽스 캐싱 지원)
- 기본 모드: 현재 상태 + 투기적 디코딩 토큰 수만큼만 추적
청크 메타데이터 계산
def _compute_chunk_metadata(self, chunk_size, num_prefills,
num_computed_tokens_p_cpu, query_start_loc_p_cpu):
"""청크별 메타데이터 생성. 두 가지 제약 조건 보장:
1. 각 청크는 단일 시퀀스의 토큰만 포함
2. chunk_size 토큰마다 mamba 상태를 확실히 추출 가능
"""
cu_chunk_seqlen = []
for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item())
# 이전 청크가 정렬되지 않은 경우 먼저 마무리
if this_num_computed % chunk_size != 0:
chunk_len = (cdiv(this_num_computed, chunk_size) * chunk_size
- this_num_computed)
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
Mamba prefill에서는 시퀀스를 chunk_size 단위로 분할하여 중간 상태를 저장한다. 이는 프리픽스 캐싱의 핵심이다 -- 청크 경계에서의 Mamba 상태를 캐싱하면, 이후 같은 프리픽스를 재계산하지 않아도 된다.
CUDA Graph 지원
def _update_metadata_for_cudagraph_capture(self, metadata):
if (metadata.num_prefills == 0
and metadata.num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()):
padded_bs = metadata.num_reqs
self.state_indices_tensor_d[:metadata.num_decodes].copy_(
state_indices_tensor_d, non_blocking=True)
state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
state_indices_tensor_d[metadata.num_decodes:] = NULL_BLOCK_ID
decode-only 배치에서 CUDA 그래프를 지원한다. 배치 크기를 패딩하고 남는 슬롯에 NULL_BLOCK_ID를 채워 정적 크기의 그래프를 재사용할 수 있게 한다.
왜 이 설계인가
-
Attention Backend 추상화: Mamba의 상태를 KV 캐시 프레임워크의 "블록"으로 매핑한다. 이를 통해 스케줄러와 메모리 관리 코드를 Transformer와 공유하면서도 SSM의 고유한 상태 관리를 지원한다.
-
Prefill/Decode 분리: SSM에서 prefill은 병렬 스캔이고 decode는 순차적 갱신이다. 메타데이터를 분리하여 각각에 최적화된 커널을 사용한다.
-
청크 기반 프리픽스 캐싱: Transformer에서는 토큰별로 KV를 캐싱하지만, Mamba에서는
chunk_size간격으로 상태를 저장한다. 단일 시퀀스만 포함하는 청크 제약은 Mamba 커널을 크게 단순화한다. -
투기적 디코딩 호환:
num_accepted_tokens텐서로 각 요청의 수락된 토큰 수를 추적하여, 올바른 상태 체크포인트에서 복원할 수 있게 한다.
논문 핵심 내용
Mamba 논문의 핵심 기여는 Selective State Space Model(S6)을 도입하여, SSM의 입력 의존적(input-dependent) 선택 메커니즘을 효율적으로 구현한 것이다. 기존 S4 모델은 시간 불변(time-invariant) 파라미터를 사용했지만, Mamba의 S6는 입력에 따라 파라미터(Delta, B, C)를 동적으로 조절한다. 이 선택성 덕분에 관련 있는 정보만 상태에 남기고 불필요한 정보는 필터링할 수 있다.
Transformer 대비 언어 모델링 성능 (Zero-shot)
| 모델 | 크기 | Pile PPL | LAMBADA Acc | HellaSwag | PIQA | 평균 |
|---|---|---|---|---|---|---|
| Pythia | 160M | 29.64 | 33.0% | 30.2% | 61.4% | 40.6% |
| Mamba | 130M | 10.56 | 44.3% | 35.3% | 64.5% | 44.7% |
| Pythia | 1B | 7.82 | 56.1% | 47.2% | 70.7% | 51.9% |
| Mamba | 790M | 7.33 | 62.7% | 55.1% | 72.1% | 57.1% |
| Mamba | 2.8B | 6.22 | 69.2% | 66.1% | 75.2% | 63.3% |
Mamba-2.8B는 동일 크기의 Transformer를 크게 앞서고, 2배 크기의 Transformer와 동등한 성능을 보인다.
추론 처리량
Mamba는 추론 시 Transformer 대비 4-5배 높은 처리량을 달성한다. 시퀀스 길이 2K 이상에서는 FlashAttention-2보다도 빠르다. 선형 시간 복잡도 덕분에 시퀀스 길이가 길어질수록 이 격차는 더 벌어진다.
Selective SSM Ablation
| Delta 선택적 | B 선택적 | C 선택적 | Perplexity |
|---|---|---|---|
| X | X | X | 10.93 |
| O | X | X | 9.81 |
| O | O | O | 8.71 |
Delta, B, C를 모두 입력 의존적으로 만들었을 때 perplexity가 10.93에서 8.71로 대폭 개선된다. 특히 상태 차원을 1에서 16으로 늘리면 파라미터가 1%만 증가하는데 perplexity는 1.0 이상 개선되어, 상태 차원 확장의 효율이 매우 높다는 것을 확인할 수 있다.
다른 도메인 결과
오디오(Speech Generation SC09) 태스크에서 Mamba-6.1M은 FID 0.94, IS 6.26을 달성하여 기존 SaShiMi(FID 1.99, IS 5.13)를 크게 앞섰다. DNA 시퀀스 모델링에서도 동일 파라미터 수 대비 3-4배 적은 파라미터로 Transformer++와 HyenaDNA에 필적하는 성능을 보였다.
참고
관련 포스트
vLLM 의 다른글
- 이전글 [vLLM] N-gram & Suffix Decoding: 모델 프리 드래프트
- 현재글 : [vLLM] Mamba (SSM): 선형 시간 복잡도 시퀀스 모델링
- 다음글 [vLLM] Context Parallelism: 컨텍스트 병렬화
댓글