[vLLM] Multi-head Latent Attention: KV 캐시를 압축하는 DeepSeek의 어텐션
들어가며
MHA(Multi-Head Attention)는 각 헤드마다 독립적인 KV를 저장하므로 KV 캐시가 크다. GQA(Grouped Query Attention)는 여러 Q 헤드가 KV 헤드를 공유하여 이를 줄이지만, **MLA(Multi-head Latent Attention)**는 더 급진적인 접근을 취한다. KV를 저차원 잠재 벡터(latent vector)로 압축하여 캐시에 저장하고, 어텐션 시 다시 복원하는 방식이다. DeepSeek-V2에서 제안되어 DeepSeek-V3에서 더 발전했다.
- 논문: DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model (arxiv 2405.04434)
- 공식 문서: https://docs.vllm.ai
공식 문서
vLLM 공식 문서: Attention Backends
핵심 구조/코드 분석
MLA 백엔드 아키텍처
vLLM은 MLA를 위한 다양한 백엔드를 vllm/v1/attention/backends/mla/ 디렉토리에서 관리한다:
mla/
├── flashmla.py # FlashMLA (NVIDIA Hopper/Blackwell 최적화)
├── flashinfer_mla.py # FlashInfer 기반 MLA
├── cutlass_mla.py # CUTLASS 기반 MLA
├── triton_mla.py # Triton 기반 MLA
├── flashmla_sparse.py # Sparse MLA
├── indexer.py # DeepSeek 전용 인덱서
└── sparse_utils.py # Sparse 유틸리티
모델(DeepSeek-V2/V3)과 하드웨어(Hopper, Blackwell, ROCm)에 따라 최적의 MLA 백엔드가 자동 선택된다.
FlashMLABackend: 하드웨어 요구사항
class FlashMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16, torch.bfloat16
]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "float16", "bfloat16", "fp8", "fp8_e4m3",
]
@staticmethod
def get_supported_kernel_block_sizes():
return [64] # MLA는 블록 크기 64 고정
@classmethod
def supports_compute_capability(cls, capability):
return capability.major in [9, 10] # Hopper, Blackwell만 지원
FlashMLA는 SM90(Hopper) 이상에서만 동작한다. 블록 크기가 64로 고정된 이유는 MLA의 잠재 벡터 차원과 타일 크기가 64에 최적화되어 있기 때문이다.
MLA의 MQA 디코딩 경로
MLA의 decode는 본질적으로 MQA(Multi-Query Attention)로 동작한다:
def _build_decode(self, block_table_tensor, seq_lens_device, ...):
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = query_lens_cpu.max().item()
num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
scheduler_metadata, _ = get_mla_metadata(
seq_lens_device,
num_q_tokens_per_head_k,
1, # MQA for the decode path
is_fp8_kvcache=self.is_fp8_kvcache,
)
KV 캐시에는 압축된 잠재 벡터가 저장되어 있으므로, 모든 Q 헤드가 하나의 KV "헤드"를 공유하는 MQA 형태가 된다. 이것이 MLA의 메모리 효율의 핵심이다.
FlashMLAImpl: 어텐션 연산
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
can_return_lse_for_decode: bool = True
def forward_mqa(self, q, kv_c_and_k_pe_cache, attn_metadata, layer):
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if type(q) is tuple:
q = torch.cat(q, dim=-1)
num_decodes = attn_metadata.num_decodes
q = reshape_query_for_spec_decode(q, num_decodes)
scheduler_metadata = attn_metadata.decode.scheduler_metadata
kv_c_and_k_pe_cache라는 이름이 MLA의 구조를 보여준다. kv_c는 압축된 KV 잠재 벡터, k_pe는 RoPE(Rotary Position Embedding)를 위한 키의 위치 임베딩 부분이다. MLA에서 위치 정보는 별도로 유지해야 하므로 이 두 가지가 결합되어 캐시에 저장된다.
FP8 MLA 지원
self.is_fp8_kvcache = is_quantized_kv_cache(
vllm_config.cache_config.cache_dtype
)
if self.is_fp8_kvcache:
tile_scheduler_metadata, num_splits = get_mla_metadata_dense_fp8(
seq_lens_device,
num_q_tokens_per_head_k,
1,
)
이미 압축된 MLA 잠재 벡터를 FP8로 양자화하면, 원본 MHA 대비 KV 캐시를 16배 이상 줄일 수 있다. FP8 경로는 별도의 타일 스케줄러 메타데이터를 사용한다.
왜 이 설계인가
-
극단적 KV 캐시 압축: MHA에서 각 헤드가 128차원 KV를 유지한다면, MLA는 모든 헤드가 공유하는 512차원 잠재 벡터 하나로 압축한다. 128개 헤드 기준 128×128×2 = 32K에서 512+64(PE) = 576으로 약 56배 압축이 가능하다.
-
다양한 하드웨어 대응: FlashMLA, FlashInfer MLA, Triton MLA, CUTLASS MLA 등 하드웨어별 최적화된 구현을 제공하여 Hopper, Blackwell, ROCm 등 다양한 환경에서 동작한다.
-
Speculative Decoding 호환:
reshape_query_for_spec_decode를 통해 여러 토큰을 동시에 처리하는 speculative decoding과도 호환된다. -
Sparse MLA로 추가 최적화: DeepSeek-V3의 sparse MLA는 어텐션 스코어가 높은 KV만 선별적으로 읽어서 디코딩 속도를 더 높인다.
MLA는 "KV 캐시가 서빙의 최대 병목"이라는 문제에 대한 모델 아키텍처 수준의 해결책이며, vLLM은 이를 다양한 커널 백엔드로 효율적으로 지원한다.
논문 핵심 내용
DeepSeek-V2 논문은 MLA(Multi-head Latent Attention)를 통해 KV 캐시를 극단적으로 압축하면서도 성능을 유지하는 방법을 제시했다. 전체 236B 파라미터 중 토큰당 21B만 활성화하는 MoE 구조와 MLA를 결합하여, 비용 효율적이면서도 고성능인 모델을 만들었다.
효율성 지표
| 항목 | DeepSeek 67B (이전) | DeepSeek-V2 | 개선 |
|---|---|---|---|
| KV 캐시 크기 | 기준 | -93.3% | 93.3% 감소 |
| 학습 비용 (GPU시간/1T토큰) | 300.6K | 172.8K | 42.5% 절감 |
| 최대 생성 처리량 | 기준 | 5.76배 | 5.76배 향상 |
| 컨텍스트 길이 | - | 128K 토큰 | - |
주요 벤치마크 성능 (21B 활성 파라미터)
| 벤치마크 | DeepSeek 67B | DeepSeek-V2 | LLaMA 3 70B |
|---|---|---|---|
| MMLU (5-shot) | 71.3% | 78.5% | 78.9% |
| GSM8K (8-shot) | 63.4% | 79.2% | - |
| HumanEval (0-shot) | 45.1% | 48.8% | 48.2% |
| MATH (4-shot) | - | 43.6% | - |
| MBPP (3-shot) | - | 66.6% | 68.6% |
21B 활성 파라미터만으로 70B 밀집 모델에 근접하거나 동등한 성능을 달성한 것이 핵심이다. KV 캐시가 93.3% 줄었기 때문에, H800 8장으로 초당 50K 토큰 이상의 생성 처리량을 달성했다. 이전 DeepSeek 67B 대비 5.76배의 처리량 향상은 MLA의 KV 캐시 압축이 실제 서빙에서 얼마나 큰 차이를 만드는지 보여준다.
관련 포스트
vLLM 의 다른글
- 이전글 [vLLM] FlashInfer: LLM 서빙에 특화된 어텐션 엔진
- 현재글 : [vLLM] Multi-head Latent Attention: KV 캐시를 압축하는 DeepSeek의 어텐션
- 다음글 [vLLM] Speculative Decoding: 드래프트 모델로 LLM 디코딩을 가속하는 원리
댓글