본문으로 건너뛰기

[vLLM] 기타 Attention Backends: GDN, Flex, Triton, DiffKV, MLA Sparse, CPU/ROCm

들어가며

vLLM은 FlashAttention 외에도 10개 이상의 어텐션 백엔드를 지원한다. 각 백엔드는 특정 하드웨어, 모델 아키텍처, 또는 최적화 전략에 특화되어 있다. vllm/v1/attention/backends/ 디렉토리의 주요 백엔드들을 하나씩 살펴본다.

핵심 구조/코드 분석

GDN Attention (GatedDeltaNet)

class GDNAttentionBackend(AttentionBackend):
    @staticmethod
    def get_name() -> str:
        return "GDN_ATTN"

    @classmethod
    def is_ssm(cls) -> bool:
        return True

@dataclass
class GDNAttentionMetadata:
    has_initial_state: torch.Tensor | None = None
    spec_query_start_loc: torch.Tensor | None = None
    non_spec_query_start_loc: torch.Tensor | None = None
    spec_state_indices_tensor: torch.Tensor | None = None
    spec_sequence_masks: torch.Tensor | None = None

GatedDeltaNet은 SSM 계열로, 투기적 디코딩과의 통합을 위해 spec_*non_spec_* 메타데이터를 분리 관리한다. has_initial_state 텐서로 각 요청이 초기 상태를 가지는지 추적한다.

FlexAttention: PyTorch 네이티브

from torch.nn.attention.flex_attention import (
    BlockMask, create_block_mask, flex_attention, and_masks, or_masks,
)

create_block_mask_compiled = torch.compile(create_block_mask, fullgraph=True, mode="reduce-overhead")
flex_attention_compiled = torch.compile(flex_attention, fullgraph=True)

PyTorch 2.x의 flex_attention API를 사용하는 백엔드다. torch.compile으로 미리 컴파일하여 오버헤드를 줄인다. BlockMask를 사용하여 커스텀 마스크 패턴(causal, sliding window 등)을 유연하게 정의할 수 있다.

Triton Attention: 순수 Triton 구현

class TritonAttentionBackend(AttentionBackend):
    # 2D 커널 최소 그리드 크기
    MIN_LAUNCH_GRID_SIZE_2D = 128
    # 타일 병렬 소프트맥스 세그먼트 수
    NUM_PAR_SOFTMAX_SEGMENTS = 16

@dataclass
class TritonAttentionMetadata:
    # context_len: 이전 반복까지의 토큰 수
    # query_len: 이번 반복의 새 토큰 수
    # seq_len: context_len + query_len

CUDA 커널 대신 순수 Triton으로 구현된 어텐션이다. FP8 KV 캐시와 ROCm AIter ops를 지원한다. Triton 기반이므로 AMD GPU에서도 동작한다.

FlashAttention DiffKV

class FlashAttentionDiffKVBackend(FlashAttentionBackend):
    head_size_v: int = 128  # V의 head size가 K와 다를 수 있음

    @classmethod
    def set_head_size_v(cls, head_size_v: int) -> None:
        cls.head_size_v = head_size_v

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN_DIFFKV"

K와 V의 head size가 다른 모델(예: YOCO)을 위한 백엔드다. FlashAttention 기반이지만 KV 캐시 형상이 비대칭이므로 별도의 triton_reshape_and_cache_flash_diffkv 커널을 사용한다.

MLA Sparse Attention

vllm/v1/attention/backends/mla/
├── flashinfer_mla_sparse.py     # FlashInfer 기반 MLA Sparse
├── flashmla_sparse.py           # FlashMLA Sparse
├── rocm_aiter_mla_sparse.py     # ROCm AIter MLA Sparse
├── xpu_mla_sparse.py            # Intel XPU MLA Sparse
├── indexer.py                   # Sparse 인덱싱 유틸

DeepSeek-V2의 Multi-head Latent Attention(MLA)을 sparse하게 구현한 백엔드 모음이다. 각 하드웨어 플랫폼별 최적화가 따로 있다.

ROCm AIter Unified Attention

class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
    @classmethod
    def get_preferred_block_size(cls, default_block_size: int) -> int:
        logger.warning_once("[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64.")
        return 64

AMD GPU용 AIter(AI Tensor Engine Runtime) 기반 통합 어텐션이다. 블록 크기를 64로 강제하는 것이 특징이며, 16의 배수인 블록 크기만 지원한다.

CPU Attention

vllm/v1/attention/backends/cpu_attn.py

CPU 추론을 위한 어텐션 백엔드도 있다. GPU가 없는 환경에서도 vLLM을 사용할 수 있게 해준다.

Mamba Attention 백엔드들

vllm/v1/attention/backends/mamba_attn.py     # 공통 인터페이스
vllm/v1/attention/backends/mamba1_attn.py    # Mamba-1
vllm/v1/attention/backends/mamba2_attn.py    # Mamba-2
vllm/v1/attention/backends/short_conv_attn.py # Short Convolution

SSM(State Space Model) 계열 모델을 위한 백엔드다. Short Convolution은 Mamba의 인풋 게이팅 단계에서 사용되는 1D 합성곱을 처리한다.

왜 이 설계인가

  1. 레지스트리 기반 백엔드 선택: 각 백엔드가 get_name()으로 고유 이름을 제공하고, 레지스트리가 모델 설정과 하드웨어에 따라 적절한 백엔드를 자동 선택한다. 새 백엔드를 추가할 때 기존 코드를 수정할 필요가 없다.

  2. SSM 플래그: is_ssm()이 True인 백엔드(GDN, Linear, Mamba)는 KV 캐시 대신 고정 크기 상태를 사용한다. 이 플래그로 스케줄러가 블록 할당 전략을 자동으로 조정한다.

  3. 플랫폼별 MLA 분리: MLA의 sparse 연산은 GPU 아키텍처에 따라 최적 구현이 크게 다르다. NVIDIA(FlashInfer/FlashMLA), AMD(AIter), Intel(XPU) 각각의 최적화를 별도 파일로 관리하여, 플랫폼 간 간섭 없이 개발할 수 있다.

참고 자료

댓글

관련 포스트

vLLM 의 다른글