본문으로 건너뛰기

[SGLang] Mamba (SSM): 선형 시간 복잡도 시퀀스 모델링

들어가며

Transformer의 Self-Attention은 시퀀스 길이 n에 대해 O(n²) 시간과 메모리를 사용한다. 시퀀스가 길어질수록 비용이 급격히 증가한다. Mamba는 State Space Model(SSM)을 기반으로 O(n) 선형 복잡도를 달성하면서도 Transformer에 필적하는 성능을 보여주는 아키텍처이다.

SGLang은 Mamba2를 서빙 엔진에 통합하여, Continuous Batching과 Chunked Prefill 환경에서 SSM 상태를 효율적으로 관리한다. 이 글에서는 python/sglang/srt/layers/attention/mamba/ 디렉토리의 mamba.pymamba2_metadata.py를 분석한다.

Transformer vs Mamba 비교

특성 Transformer Mamba (SSM)
시간 복잡도 O(n²) O(n)
메모리 복잡도 O(n²) 또는 O(n) (FlashAttn) O(1) 상태 크기
병렬 학습 완전 병렬 청크 기반 병렬
추론 (Decode) KV 캐시 필요 (O(n) 메모리) 고정 크기 상태
긴 시퀀스 비용 급증 선형 확장
전역 의존성 직접 참조 상태 압축으로 간접 참조

구조도

┌─────────────────────────────────────────────────┐
│                  MambaMixer2                     │
│                                                 │
│  hidden_states                                  │
│       │                                         │
│       ▼                                         │
│  ┌──────────┐                                   │
│  │ in_proj  │ → gate, hidden_states_B_C, dt     │
│  └──────────┘                                   │
│       │                                         │
│       ▼                                         │
│  ┌──────────────┐     ┌────────────────┐        │
│  │ causal_conv1d│     │ conv_state     │        │
│  │  (depthwise) │ ←──▶│ (캐시 관리)    │        │
│  └──────┬───────┘     └────────────────┘        │
│         │                                       │
│         ▼                                       │
│  ┌──────────────────────────┐  ┌─────────────┐  │
│  │ mamba_chunk_scan_combined│  │ ssm_state   │  │
│  │  (Selective Scan)        │──▶ (상태 관리)  │  │
│  └──────────┬───────────────┘  └─────────────┘  │
│             │                                   │
│             ▼                                   │
│  ┌──────────────┐                               │
│  │  norm + gate │ → out_proj → output           │
│  └──────────────┘                               │
└─────────────────────────────────────────────────┘

핵심 코드 분석

MambaMixer2: SSM 레이어 구조

class MambaMixer2(torch.nn.Module):
    def __init__(self, cache_params, hidden_size, ...):
        self.num_heads = num_heads = cache_params.shape.num_heads
        self.head_dim = cache_params.shape.head_dim
        self.ssm_state_size = cache_params.shape.ssm_state_size
        self.intermediate_size = cache_params.shape.intermediate_size

        self.conv1d = MergedColumnParallelLinear(
            input_size=conv_kernel_size,
            output_sizes=[
                intermediate_size,
                self.groups_ssm_state_size,
                self.groups_ssm_state_size,
            ],
            ...
        )

        self.in_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[
                intermediate_size,      # gate
                intermediate_size,      # x (hidden states)
                self.groups_ssm_state_size,  # B
                self.groups_ssm_state_size,  # C
                self.num_heads,         # dt (delta time)
            ],
            ...
        )

in_proj는 하나의 행렬 곱으로 5개의 출력을 생성한다: gate, hidden states, B (입력 행렬), C (출력 행렬), dt (시간 스텝). Mamba의 핵심인 Selective Scan에서 B, C, dt가 입력에 의존적(input-dependent)이라는 점이 고전적 SSM과의 차이이다.

Tensor Parallel 샤딩: 그룹 복제 전략

if n_groups % self.tp_size != 0:
    groups = extra_groups_for_head_shards(n_groups, self.tp_size)
    self.n_groups = n_groups + groups

TP(Tensor Parallelism) 적용 시 num_heads는 TP 크기로 나누어 분배한다. 문제는 n_groups가 TP 크기로 나누어지지 않을 때이다. 예를 들어 n_groups=1이면 모든 헤드가 같은 그룹의 B, C를 공유해야 한다. SGLang은 이 경우 그룹을 복제(replicate)하여 각 TP 샤드가 필요한 그룹을 갖도록 extra_groups_for_head_shards로 추가 공간을 할당한다.

Forward: Prefill과 Decode 분리 처리

def forward(self, *, hidden_states, output, layer_cache, metadata, ...):
    # 1. Linear projection
    projected_states, _ = self.in_proj(hidden_states)
    gate, hidden_states_B_C, dt = torch.split(projected_states, [...], dim=-1)

    # Prefill과 Decode 토큰 분리
    hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
        hidden_states_B_C, [num_prefill_tokens, num_decode_tokens], dim=0,
    )

    # Prefill: causal_conv1d_fn으로 전체 시퀀스 처리
    if has_prefill:
        hidden_states_B_C_p = causal_conv1d_fn(
            x, conv_weights, self.conv1d.bias,
            activation=self.activation,
            conv_states=conv_state,
            has_initial_state=has_initial_states_p,
            cache_indices=cache_indices,
            query_start_loc=query_start_loc_p,
        ).transpose(0, 1)[:num_prefill_tokens]

하나의 배치에 Prefill 요청과 Decode 요청이 섞여 있을 수 있다. SGLang은 토큰 차원에서 Prefill과 Decode를 분리한 후 각각 다른 경로로 처리한다. Prefill은 causal_conv1d_fn으로 전체 시퀀스를 한 번에 Convolution하고, Decode는 causal_conv1d_update로 한 토큰씩 상태를 업데이트한다.

Selective Scan: 청크 기반 병렬 처리

varlen_state = mamba_chunk_scan_combined(
    hidden_states_p.view(
        1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
    ),
    dt_p.unsqueeze(0),
    self.A,
    B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
    C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
    chunk_size=mixed_metadata.chunk_size,
    D=self.D,
    dt_bias=self.dt_bias,
    seq_idx=mixed_metadata.seq_idx,
    chunk_indices=mixed_metadata.chunk_indices,
    chunk_offsets=mixed_metadata.chunk_offsets,
    cu_seqlens=query_start_loc_p,
    ...
)

Mamba의 Selective Scan은 순차적 연산이지만, 청크 단위로 나누면 청크 내부는 행렬 곱으로 병렬화할 수 있다. mamba_chunk_scan_combined는 이 청크 병렬화된 Selective Scan을 수행한다. seq_idx는 Variable Length 배치에서 시퀀스 경계를 추적하고, chunk_indiceschunk_offsets는 물리 청크와 논리 청크의 매핑을 관리한다.

Mamba2Metadata: 청크 인덱스 사전 계산

@staticmethod
def _query_start_loc_to_chunk_indices_offsets(
    query_start_loc, chunk_size, total_seqlens
):
    # 논리 청크: 시퀀스 경계와 물리 청크 경계가 교차하는 지점에서 분할
    N = (
        math.ceil(total_seqlens / chunk_size)
        + (cu_seqlens[:-1] % chunk_size > 0).sum()
    )
    chunk_indices = torch.arange(N, dtype=torch.int, device=query_start_loc.device)
    chunk_offsets = torch.zeros((N,), dtype=torch.int, device=query_start_loc.device)

    p = 0  # 삽입 횟수
    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        p += s % chunk_size > 0
        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0)
        chunk_indices[_s:_e] -= p
        chunk_offsets[_s] = s % chunk_size

    return chunk_indices, chunk_offsets

Variable Length 배치에서 여러 시퀀스가 연속으로 이어붙여진다. 물리 청크(chunk_size 단위)와 시퀀스 경계가 일치하지 않으면 한 물리 청크 안에 두 시퀀스의 토큰이 섞일 수 있다. chunk_indiceschunk_offsets는 이 경계를 추적하여 Selective Scan이 시퀀스를 넘어가지 않도록 보장한다. 이 메타데이터는 모델 forward 최상위에서 한 번 계산되어 모든 Mamba 레이어에서 재사용된다.

상태 관리: Conv State + SSM State

state_indices_tensor = metadata.mamba_cache_indices
conv_state = layer_cache.conv[0]     # [pool_size, conv_dim, conv_kernel]
ssm_state = layer_cache.temporal     # [pool_size, num_heads, head_dim, ssm_state_size]

Mamba는 KV 캐시 대신 두 가지 상태를 유지한다. conv_state는 Causal Convolution의 슬라이딩 윈도우 상태이고, ssm_state는 Selective Scan의 재귀 상태이다. 두 상태 모두 고정 크기이므로, 시퀀스가 길어져도 메모리 사용량이 증가하지 않는다. 이것이 Mamba가 긴 시퀀스에서 유리한 핵심 이유이다.

설계 근거: Continuous Batching에서의 SSM 상태 관리

Transformer의 KV 캐시는 토큰별로 독립적이라 Continuous Batching에 자연스럽다. 반면 SSM 상태는 재귀적이므로, 요청이 추가/제거될 때 상태를 올바르게 복원해야 한다. SGLang은 mamba_cache_indices로 각 요청의 상태 슬롯을 추적하고, has_initial_state 플래그로 캐시 히트 시 이전 상태를 복원한다. Prefix Caching도 SSM 상태 단위로 지원되어, 동일 프리픽스를 공유하는 요청은 SSM 상태를 복제하여 재사용한다.

관련 포스트

  • Hybrid Attention: Dense-Sparse 동적 전환 전략
  • GDN (Gated Diagonal Net): 게이트 기반 선형 어텐션
  • FLA (Flashy Linear Attention): 청크 기반 선형 어텐션 연산

참고

댓글

관련 포스트

SGLang 의 다른글