본문으로 건너뛰기

[SGLang] Multi-head Latent Attention (MLA): KV 캐시 압축 어텐션

들어가며

MHA(Multi-Head Attention)는 각 head마다 독립적인 K, V를 저장하므로 KV 캐시가 head 수에 비례하여 증가한다. GQA(Grouped Query Attention)는 KV head를 공유하여 이를 줄이지만, 표현력과 트레이드오프가 있다. DeepSeek-V2에서 제안된 MLA(Multi-head Latent Attention)는 다른 접근법을 택한다. 고차원 KV를 저차원 latent representation으로 압축하여 캐시에 저장하고, 어텐션 시점에 원래 차원으로 복원한다.

SGLang은 MLA를 위해 FlashInfer MLA, FlashMLA, CUTLASS MLA 세 가지 백엔드를 제공한다. 이 글에서는 이 세 가지 구현을 비교 분석한다.

MHA vs MLA: 원리 비교

MHA: 각 Head 독립 저장

  Q (h heads x d_h)    K (h heads x d_h)    V (h heads x d_h)
       │                     │                    │
       ▼                     ▼                    ▼
  ┌─────────┐          ┌─────────┐          ┌─────────┐
  │ head 1  │          │ head 1  │          │ head 1  │
  │ head 2  │          │ head 2  │          │ head 2  │
  │  ...    │          │  ...    │          │  ...    │
  │ head h  │          │ head h  │          │ head h  │
  └─────────┘          └─────────┘          └─────────┘
                       KV Cache: h x d_h x 2 per token

MLA: Latent Compression + Absorption

  Hidden State
       │
       ├──── W_DKV ──── c_kv (d_c 차원, d_c << h x d_h)
       │                  │
       │            ┌─────┴─────┐
       │            │           │
       │         W_UK → K    W_UV → V
       │       (복원, absorbed into Q)
       │
       ├──── W_QKV ──── c_q (q latent)
       │                  │
       │               W_UQ → Q
       │
       └──── RoPE ──── k_rope (d_rope 차원)

  KV Cache: (d_c + d_rope) per token
  = (512 + 64) = 576  vs  MHA: 128 x 128 x 2 = 32768

MLA의 핵심 트릭은 "absorption"이다. 어텐션 연산에서 Q @ K^T를 계산할 때, K = W_UK @ c_kv이므로 Q @ (W_UK @ c_kv)^T = (Q @ W_UK^T) @ c_kv^T가 된다. W_UK를 Q 쪽에 흡수(absorb)하면 c_kv를 직접 KV 캐시로 사용할 수 있다.

MHA vs MLA 비교표

항목 MHA (128 heads) GQA (8 KV heads) MLA (DeepSeek-V2)
KV 캐시 / 토큰 2 x 128 x 128 = 32KB 2 x 8 x 128 = 2KB 576B (d_c + d_rope)
캐시 압축률 1x 16x ~57x
Q-K 표현력 각 head 독립 KV head 공유 Latent space에서 복원
추가 연산 없음 없음 Absorption 행렬곱
처리량 (tokens/sec) 1x (기준) ~4x ~7x (논문 기준)

DeepSeek-V2 논문에 따르면, MLA는 GQA-4(4 KV groups) 대비 동등한 성능을 유지하면서 KV 캐시를 약 7배 줄인다. 캐시 크기 감소는 더 큰 배치 크기를 가능하게 하여 처리량이 증가한다.

FlashInfer MLA 백엔드

FlashInferMLAAttnBackend는 FlashInfer의 BatchMLAPagedAttentionWrapper를 사용한다.

class FlashInferMLAAttnBackend(AttentionBackend):
    def __init__(self, model_runner, skip_prefill=False, ...):
        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
            self.workspace_buffer, "NHD", backend=self.fmha_backend
        )
        self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
            self.workspace_buffer, backend="auto",
        )
        self.decode_wrapper = BatchMLAPagedAttentionWrapper(
            self.workspace_buffer, backend="auto"
        )

Prefill에서는 ragged wrapper와 MLA paged wrapper를 조합한다. Decode에서는 BatchMLAPagedAttentionWrapper가 compressed KV cache에서 직접 어텐션을 계산한다.

Ragged Prefill vs MLA Paged Prefill

def forward_extend(self, q, k, v, layer, forward_batch, ...):
    if self.forward_metadata.use_ragged:
        # ragged: Q,K,V를 직접 사용 (absorption 전 상태)
        qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
        if k_rope is not None:
            k = torch.cat([k, k_rope], dim=-1)
        o = self.prefill_wrapper_ragged.forward(
            qall, k.to(q.dtype), v.to(q.dtype), causal=True,
        )
    else:
        # paged: absorbed Q를 q_nope, q_rope로 분리
        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        q, q_rope = qall[:, :, :layer.v_head_dim], qall[:, :, layer.v_head_dim:]
        o = prefill_wrapper_paged.run(
            q, q_rope,
            k_buf[:, :, :layer.v_head_dim],
            k_buf[:, :, layer.v_head_dim:],
        )

Ragged 경로는 prefix가 없을 때 사용하며, MLA의 absorption을 적용하지 않고 일반 MHA처럼 처리한다. Paged 경로는 캐시에 저장된 compressed KV를 사용하며, q_nopeq_rope를 분리하여 MLA 커널에 전달한다.

Decode: MLA Paged 어텐션

def forward_decode(self, q, k, v, layer, forward_batch, ...):
    decode_wrapper = self.forward_metadata.decode_wrapper
    q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
    q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim)

    k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
    o = decode_wrapper.run(
        q_nope, q_rope,
        k_buffer[:, :, :layer.v_head_dim],
        k_buffer[:, :, layer.v_head_dim:],
    )
    return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

FlashMLA 백엔드

FlashMLABackendFlashInferMLAAttnBackend를 상속하여 Decode만 오버라이드한다.

class FlashMLABackend(FlashInferMLAAttnBackend):
    def init_forward_metadata(self, forward_batch):
        if forward_batch.forward_mode.is_decode_or_idle():
            block_kv_indices = torch.full(
                (bs, max_seqlen_pad), -1, dtype=torch.int32, device=device,
            )
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token, ..., block_kv_indices, ...,
            )
            mla_metadata, num_splits = get_mla_metadata(
                forward_batch.seq_lens.to(torch.int32),
                self.num_q_heads, 1,
            )

FlashMLA는 PAGE_SIZE=64의 고정 페이지 크기를 사용한다. get_mla_metadata는 MLA 커널의 block 스케줄링 메타데이터를 계산한다. Prefill은 FlashInfer MLA에 위임하고, Decode만 FlashMLA 커널(flash_mla_with_kvcache)을 사용한다.

CUTLASS MLA 백엔드

class CutlassMLABackend(FlashInferMLAAttnBackend):
    def __init__(self, model_runner, ...):
        # CUTLASS MLA는 PAGE_SIZE=128만 지원
        ...

    def init_forward_metadata(self, forward_batch):
        if forward_batch.forward_mode.is_decode_or_idle():
            workspace_size = cutlass_mla_get_workspace_size(
                max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
            )
            workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

CUTLASS MLA는 NVIDIA의 CUTLASS 라이브러리 기반으로, PAGE_SIZE=128을 사용한다. cutlass_mla_decode를 호출하여 어텐션을 계산하며, workspace 버퍼를 미리 할당한다. Prefill은 FlashInfer MLA에 위임한다.

Chunked Prefix KV Cache

FlashInfer MLA 백엔드는 chunked prefix cache를 위한 FlashInferMhaChunkKVRunner를 제공한다.

class FlashInferMhaChunkKVRunner:
    def __init__(self, model_runner, attn_backend):
        self.chunk_ragged_wrappers = []
        self.ragged_wrapper = attn_backend.prefill_wrapper_ragged

    def forward(self, q, k, v, layer, forward_batch):
        if forward_batch.attn_attend_prefix_cache:
            chunk_idx = forward_batch.prefix_chunk_idx
            wrapper = self.chunk_ragged_wrappers[chunk_idx]
            o = wrapper.forward_return_lse(
                q.view(...), k.view(...).to(q.dtype),
                v.view(...).to(q.dtype), causal=False,
            )

MLA 모델에서도 prefix cache는 MHA 형태로 저장될 수 있다. 이 경우 chunk 단위로 MHA 어텐션을 수행한 뒤 결과를 합산한다.

3종 백엔드 비교

항목 FlashInfer MLA FlashMLA CUTLASS MLA
Prefill 자체 구현 FlashInfer 위임 FlashInfer 위임
Decode BatchMLAPagedAttn flash_mla_with_kvcache cutlass_mla_decode
Page Size 가변 64 128
FP8 KV 지원 지원 미지원
CUDA Graph 지원 지원 지원
SM 요구 SM80+ SM80+ SM90+
상속 관계 Base extends FlashInfer MLA extends FlashInfer MLA

FlashMLA와 CUTLASS MLA는 Decode 성능에 최적화된 백엔드다. Prefill은 FlashInfer MLA의 구현을 그대로 재사용하므로, 선택 기준은 Decode 성능과 하드웨어 호환성이다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글