[vLLM] 기타 Attention Backends: GDN, Flex, Triton, DiffKV, MLA Sparse, CPU/ROCm
들어가며
vLLM은 FlashAttention 외에도 10개 이상의 어텐션 백엔드를 지원한다. 각 백엔드는 특정 하드웨어, 모델 아키텍처, 또는 최적화 전략에 특화되어 있다. vllm/v1/attention/backends/ 디렉토리의 주요 백엔드들을 하나씩 살펴본다.
핵심 구조/코드 분석
GDN Attention (GatedDeltaNet)
class GDNAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "GDN_ATTN"
@classmethod
def is_ssm(cls) -> bool:
return True
@dataclass
class GDNAttentionMetadata:
has_initial_state: torch.Tensor | None = None
spec_query_start_loc: torch.Tensor | None = None
non_spec_query_start_loc: torch.Tensor | None = None
spec_state_indices_tensor: torch.Tensor | None = None
spec_sequence_masks: torch.Tensor | None = None
GatedDeltaNet은 SSM 계열로, 투기적 디코딩과의 통합을 위해 spec_*와 non_spec_* 메타데이터를 분리 관리한다. has_initial_state 텐서로 각 요청이 초기 상태를 가지는지 추적한다.
FlexAttention: PyTorch 네이티브
from torch.nn.attention.flex_attention import (
BlockMask, create_block_mask, flex_attention, and_masks, or_masks,
)
create_block_mask_compiled = torch.compile(create_block_mask, fullgraph=True, mode="reduce-overhead")
flex_attention_compiled = torch.compile(flex_attention, fullgraph=True)
PyTorch 2.x의 flex_attention API를 사용하는 백엔드다. torch.compile으로 미리 컴파일하여 오버헤드를 줄인다. BlockMask를 사용하여 커스텀 마스크 패턴(causal, sliding window 등)을 유연하게 정의할 수 있다.
Triton Attention: 순수 Triton 구현
class TritonAttentionBackend(AttentionBackend):
# 2D 커널 최소 그리드 크기
MIN_LAUNCH_GRID_SIZE_2D = 128
# 타일 병렬 소프트맥스 세그먼트 수
NUM_PAR_SOFTMAX_SEGMENTS = 16
@dataclass
class TritonAttentionMetadata:
# context_len: 이전 반복까지의 토큰 수
# query_len: 이번 반복의 새 토큰 수
# seq_len: context_len + query_len
CUDA 커널 대신 순수 Triton으로 구현된 어텐션이다. FP8 KV 캐시와 ROCm AIter ops를 지원한다. Triton 기반이므로 AMD GPU에서도 동작한다.
FlashAttention DiffKV
class FlashAttentionDiffKVBackend(FlashAttentionBackend):
head_size_v: int = 128 # V의 head size가 K와 다를 수 있음
@classmethod
def set_head_size_v(cls, head_size_v: int) -> None:
cls.head_size_v = head_size_v
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_DIFFKV"
K와 V의 head size가 다른 모델(예: YOCO)을 위한 백엔드다. FlashAttention 기반이지만 KV 캐시 형상이 비대칭이므로 별도의 triton_reshape_and_cache_flash_diffkv 커널을 사용한다.
MLA Sparse Attention
vllm/v1/attention/backends/mla/
├── flashinfer_mla_sparse.py # FlashInfer 기반 MLA Sparse
├── flashmla_sparse.py # FlashMLA Sparse
├── rocm_aiter_mla_sparse.py # ROCm AIter MLA Sparse
├── xpu_mla_sparse.py # Intel XPU MLA Sparse
├── indexer.py # Sparse 인덱싱 유틸
DeepSeek-V2의 Multi-head Latent Attention(MLA)을 sparse하게 구현한 백엔드 모음이다. 각 하드웨어 플랫폼별 최적화가 따로 있다.
ROCm AIter Unified Attention
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
@classmethod
def get_preferred_block_size(cls, default_block_size: int) -> int:
logger.warning_once("[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64.")
return 64
AMD GPU용 AIter(AI Tensor Engine Runtime) 기반 통합 어텐션이다. 블록 크기를 64로 강제하는 것이 특징이며, 16의 배수인 블록 크기만 지원한다.
CPU Attention
vllm/v1/attention/backends/cpu_attn.py
CPU 추론을 위한 어텐션 백엔드도 있다. GPU가 없는 환경에서도 vLLM을 사용할 수 있게 해준다.
Mamba Attention 백엔드들
vllm/v1/attention/backends/mamba_attn.py # 공통 인터페이스
vllm/v1/attention/backends/mamba1_attn.py # Mamba-1
vllm/v1/attention/backends/mamba2_attn.py # Mamba-2
vllm/v1/attention/backends/short_conv_attn.py # Short Convolution
SSM(State Space Model) 계열 모델을 위한 백엔드다. Short Convolution은 Mamba의 인풋 게이팅 단계에서 사용되는 1D 합성곱을 처리한다.
왜 이 설계인가
-
레지스트리 기반 백엔드 선택: 각 백엔드가
get_name()으로 고유 이름을 제공하고, 레지스트리가 모델 설정과 하드웨어에 따라 적절한 백엔드를 자동 선택한다. 새 백엔드를 추가할 때 기존 코드를 수정할 필요가 없다. -
SSM 플래그:
is_ssm()이 True인 백엔드(GDN, Linear, Mamba)는 KV 캐시 대신 고정 크기 상태를 사용한다. 이 플래그로 스케줄러가 블록 할당 전략을 자동으로 조정한다. -
플랫폼별 MLA 분리: MLA의 sparse 연산은 GPU 아키텍처에 따라 최적 구현이 크게 다르다. NVIDIA(FlashInfer/FlashMLA), AMD(AIter), Intel(XPU) 각각의 최적화를 별도 파일로 관리하여, 플랫폼 간 간섭 없이 개발할 수 있다.
참고 자료
관련 포스트
vLLM 의 다른글
- 이전글 [vLLM] Lightning & Linear Attention: 선형 어텐션 구현
- 현재글 : [vLLM] 기타 Attention Backends: GDN, Flex, Triton, DiffKV, MLA Sparse, CPU/ROCm
- 다음글 [vLLM] 기타 Model Layers: Pooler, Resampler, Vocab Parallel Embedding 등
댓글