[SGLang] Multi-head Latent Attention (MLA): KV 캐시 압축 어텐션
들어가며
MHA(Multi-Head Attention)는 각 head마다 독립적인 K, V를 저장하므로 KV 캐시가 head 수에 비례하여 증가한다. GQA(Grouped Query Attention)는 KV head를 공유하여 이를 줄이지만, 표현력과 트레이드오프가 있다. DeepSeek-V2에서 제안된 MLA(Multi-head Latent Attention)는 다른 접근법을 택한다. 고차원 KV를 저차원 latent representation으로 압축하여 캐시에 저장하고, 어텐션 시점에 원래 차원으로 복원한다.
SGLang은 MLA를 위해 FlashInfer MLA, FlashMLA, CUTLASS MLA 세 가지 백엔드를 제공한다. 이 글에서는 이 세 가지 구현을 비교 분석한다.
MHA vs MLA: 원리 비교
MHA: 각 Head 독립 저장
Q (h heads x d_h) K (h heads x d_h) V (h heads x d_h)
│ │ │
▼ ▼ ▼
┌─────────┐ ┌─────────┐ ┌─────────┐
│ head 1 │ │ head 1 │ │ head 1 │
│ head 2 │ │ head 2 │ │ head 2 │
│ ... │ │ ... │ │ ... │
│ head h │ │ head h │ │ head h │
└─────────┘ └─────────┘ └─────────┘
KV Cache: h x d_h x 2 per token
MLA: Latent Compression + Absorption
Hidden State
│
├──── W_DKV ──── c_kv (d_c 차원, d_c << h x d_h)
│ │
│ ┌─────┴─────┐
│ │ │
│ W_UK → K W_UV → V
│ (복원, absorbed into Q)
│
├──── W_QKV ──── c_q (q latent)
│ │
│ W_UQ → Q
│
└──── RoPE ──── k_rope (d_rope 차원)
KV Cache: (d_c + d_rope) per token
= (512 + 64) = 576 vs MHA: 128 x 128 x 2 = 32768
MLA의 핵심 트릭은 "absorption"이다. 어텐션 연산에서 Q @ K^T를 계산할 때, K = W_UK @ c_kv이므로 Q @ (W_UK @ c_kv)^T = (Q @ W_UK^T) @ c_kv^T가 된다. W_UK를 Q 쪽에 흡수(absorb)하면 c_kv를 직접 KV 캐시로 사용할 수 있다.
MHA vs MLA 비교표
| 항목 | MHA (128 heads) | GQA (8 KV heads) | MLA (DeepSeek-V2) |
|---|---|---|---|
| KV 캐시 / 토큰 | 2 x 128 x 128 = 32KB | 2 x 8 x 128 = 2KB | 576B (d_c + d_rope) |
| 캐시 압축률 | 1x | 16x | ~57x |
| Q-K 표현력 | 각 head 독립 | KV head 공유 | Latent space에서 복원 |
| 추가 연산 | 없음 | 없음 | Absorption 행렬곱 |
| 처리량 (tokens/sec) | 1x (기준) | ~4x | ~7x (논문 기준) |
DeepSeek-V2 논문에 따르면, MLA는 GQA-4(4 KV groups) 대비 동등한 성능을 유지하면서 KV 캐시를 약 7배 줄인다. 캐시 크기 감소는 더 큰 배치 크기를 가능하게 하여 처리량이 증가한다.
FlashInfer MLA 백엔드
FlashInferMLAAttnBackend는 FlashInfer의 BatchMLAPagedAttentionWrapper를 사용한다.
class FlashInferMLAAttnBackend(AttentionBackend):
def __init__(self, model_runner, skip_prefill=False, ...):
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD", backend=self.fmha_backend
)
self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
self.workspace_buffer, backend="auto",
)
self.decode_wrapper = BatchMLAPagedAttentionWrapper(
self.workspace_buffer, backend="auto"
)
Prefill에서는 ragged wrapper와 MLA paged wrapper를 조합한다. Decode에서는 BatchMLAPagedAttentionWrapper가 compressed KV cache에서 직접 어텐션을 계산한다.
Ragged Prefill vs MLA Paged Prefill
def forward_extend(self, q, k, v, layer, forward_batch, ...):
if self.forward_metadata.use_ragged:
# ragged: Q,K,V를 직접 사용 (absorption 전 상태)
qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
if k_rope is not None:
k = torch.cat([k, k_rope], dim=-1)
o = self.prefill_wrapper_ragged.forward(
qall, k.to(q.dtype), v.to(q.dtype), causal=True,
)
else:
# paged: absorbed Q를 q_nope, q_rope로 분리
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
q, q_rope = qall[:, :, :layer.v_head_dim], qall[:, :, layer.v_head_dim:]
o = prefill_wrapper_paged.run(
q, q_rope,
k_buf[:, :, :layer.v_head_dim],
k_buf[:, :, layer.v_head_dim:],
)
Ragged 경로는 prefix가 없을 때 사용하며, MLA의 absorption을 적용하지 않고 일반 MHA처럼 처리한다. Paged 경로는 캐시에 저장된 compressed KV를 사용하며, q_nope와 q_rope를 분리하여 MLA 커널에 전달한다.
Decode: MLA Paged 어텐션
def forward_decode(self, q, k, v, layer, forward_batch, ...):
decode_wrapper = self.forward_metadata.decode_wrapper
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim)
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
o = decode_wrapper.run(
q_nope, q_rope,
k_buffer[:, :, :layer.v_head_dim],
k_buffer[:, :, layer.v_head_dim:],
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
FlashMLA 백엔드
FlashMLABackend는 FlashInferMLAAttnBackend를 상속하여 Decode만 오버라이드한다.
class FlashMLABackend(FlashInferMLAAttnBackend):
def init_forward_metadata(self, forward_batch):
if forward_batch.forward_mode.is_decode_or_idle():
block_kv_indices = torch.full(
(bs, max_seqlen_pad), -1, dtype=torch.int32, device=device,
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token, ..., block_kv_indices, ...,
)
mla_metadata, num_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32),
self.num_q_heads, 1,
)
FlashMLA는 PAGE_SIZE=64의 고정 페이지 크기를 사용한다. get_mla_metadata는 MLA 커널의 block 스케줄링 메타데이터를 계산한다. Prefill은 FlashInfer MLA에 위임하고, Decode만 FlashMLA 커널(flash_mla_with_kvcache)을 사용한다.
CUTLASS MLA 백엔드
class CutlassMLABackend(FlashInferMLAAttnBackend):
def __init__(self, model_runner, ...):
# CUTLASS MLA는 PAGE_SIZE=128만 지원
...
def init_forward_metadata(self, forward_batch):
if forward_batch.forward_mode.is_decode_or_idle():
workspace_size = cutlass_mla_get_workspace_size(
max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
CUTLASS MLA는 NVIDIA의 CUTLASS 라이브러리 기반으로, PAGE_SIZE=128을 사용한다. cutlass_mla_decode를 호출하여 어텐션을 계산하며, workspace 버퍼를 미리 할당한다. Prefill은 FlashInfer MLA에 위임한다.
Chunked Prefix KV Cache
FlashInfer MLA 백엔드는 chunked prefix cache를 위한 FlashInferMhaChunkKVRunner를 제공한다.
class FlashInferMhaChunkKVRunner:
def __init__(self, model_runner, attn_backend):
self.chunk_ragged_wrappers = []
self.ragged_wrapper = attn_backend.prefill_wrapper_ragged
def forward(self, q, k, v, layer, forward_batch):
if forward_batch.attn_attend_prefix_cache:
chunk_idx = forward_batch.prefix_chunk_idx
wrapper = self.chunk_ragged_wrappers[chunk_idx]
o = wrapper.forward_return_lse(
q.view(...), k.view(...).to(q.dtype),
v.view(...).to(q.dtype), causal=False,
)
MLA 모델에서도 prefix cache는 MHA 형태로 저장될 수 있다. 이 경우 chunk 단위로 MHA 어텐션을 수행한 뒤 결과를 합산한다.
3종 백엔드 비교
| 항목 | FlashInfer MLA | FlashMLA | CUTLASS MLA |
|---|---|---|---|
| Prefill | 자체 구현 | FlashInfer 위임 | FlashInfer 위임 |
| Decode | BatchMLAPagedAttn | flash_mla_with_kvcache | cutlass_mla_decode |
| Page Size | 가변 | 64 | 128 |
| FP8 KV | 지원 | 지원 | 미지원 |
| CUDA Graph | 지원 | 지원 | 지원 |
| SM 요구 | SM80+ | SM80+ | SM90+ |
| 상속 관계 | Base | extends FlashInfer MLA | extends FlashInfer MLA |
FlashMLA와 CUTLASS MLA는 Decode 성능에 최적화된 백엔드다. Prefill은 FlashInfer MLA의 구현을 그대로 재사용하므로, 선택 기준은 Decode 성능과 하드웨어 호환성이다.
관련 포스트
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] FlashInfer: 래그드 텐서 어텐션 엔진
- 현재글 : [SGLang] Multi-head Latent Attention (MLA): KV 캐시 압축 어텐션
- 다음글 [SGLang] NSA (Narrow Sparse Attention): DeepSeek의 스파스 어텐션
댓글