본문으로 건너뛰기

[vLLM] Lightning & Linear Attention: 선형 어텐션 구현

들어가며

표준 Transformer 어텐션은 시퀀스 길이에 대해 O(n^2) 복잡도를 가진다. 선형 어텐션(Linear Attention)은 이를 O(n)으로 줄이는 접근이다. vLLM은 vllm/v1/attention/backends/linear_attn.py에서 선형 어텐션 백엔드를, vllm/model_executor/layers/lightning_attn.py에서 Triton 기반 Lightning Attention 커널을 구현한다. Lightning Attention 논문(arxiv:2401.04658)의 아이디어를 기반으로 한다.

핵심 구조/코드 분석

LinearAttentionBackend

class LinearAttentionBackend(AttentionBackend):
    @staticmethod
    def get_name() -> str:
        return "LINEAR_ATTN"

    @classmethod
    def is_ssm(cls) -> bool:
        return True  # SSM(State Space Model)처럼 취급

핵심은 is_ssm() -> True다. 선형 어텐션은 KV 캐시를 전통적인 토큰별 저장 대신, SSM처럼 고정 크기 상태(state)로 관리한다.

LinearAttentionMetadata

@dataclass
class LinearAttentionMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    state_indices_tensor: torch.Tensor  # shape: [batch,]

state_indices_tensor가 핵심이다. 각 요청이 어떤 상태 슬롯을 사용하는지 가리킨다. 전통적 KV 캐시의 블록 테이블 대신 단일 인덱스로 상태를 참조한다.

메타데이터 빌드

class LinearAttentionMetadataBuilder(AttentionMetadataBuilder):
    reorder_batch_threshold: int = 1
    _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    def build(self, common_prefix_len, common_attn_metadata, fast_build=False):
        state_indices_tensor = mamba_get_block_table_tensor(
            common_attn_metadata.block_table_tensor,
            common_attn_metadata.seq_lens,
            self.kv_cache_spec,
            self.vllm_config.cache_config.mamba_cache_mode,
        )[:, 0]  # 첫 번째 블록만 사용 (상태 1개)

Mamba의 블록 테이블 유틸리티를 재사용하여 상태 인덱스를 구한다. [:, 0]으로 첫 번째 블록만 추출하는 것은, 선형 어텐션의 상태가 항상 1개의 블록에 담기기 때문이다.

Lightning Attention: Triton 커널

@triton.jit
def _fwd_diag_kernel(Q, K, V, Out, S, b, h, n, d, e, BLOCK, NUM_BLOCK, CBLOCK):
    """대각선 블록 연산: 같은 블록 내의 Q-K 어텐션"""
    off = tl.program_id(0)
    off_bh = off // NUM_BLOCK  # batch-head 인덱스
    off_block = off % NUM_BLOCK  # 시퀀스 내 블록 인덱스
    off_cblock = tl.program_id(1)  # 블록 내 서브블록 인덱스

    Q_block_ptr = Q + qk_offset + qk_block_offset + q_cblock_offset + ...
    K_trans_block_ptr = K + qk_offset + qk_block_offset + ...
    V_block_ptr = V + v_offset + v_block_offset + ...

Lightning Attention은 시퀀스를 블록으로 나누어 두 가지 연산을 분리한다:

  1. Diagonal blocks: 같은 블록 내의 Q-K 어텐션 (intra-block)
  2. Off-diagonal blocks: 이전 블록들의 정보를 상태(S)로 누적 (inter-block)

이 분리로 intra-block은 정확한 어텐션을, inter-block은 선형 시간 상태 업데이트를 수행한다.

Prefill/Decode 분리

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
    split_decodes_and_prefills(
        common_attn_metadata, decode_threshold=self.reorder_batch_threshold
    )
)

reorder_batch_threshold=1이므로, 쿼리 길이가 1인 요청은 decode, 그 이상은 prefill로 분류한다. decode 시에는 상태에서 바로 읽기만 하면 되어 매우 빠르다.

왜 이 설계인가

  1. SSM 인프라 재사용: vLLM은 이미 Mamba 같은 SSM 모델을 지원하고 있다. 선형 어텐션도 고정 크기 상태를 사용하므로, SSM의 상태 관리 인프라(블록 테이블, 상태 인덱스, CUDA 그래프 지원)를 그대로 재사용할 수 있다. is_ssm() -> True가 이 재사용을 활성화한다.

  2. Diagonal/Off-diagonal 분리: Lightning Attention의 핵심 인사이트다. 소프트맥스 어텐션을 커널 형태의 선형 어텐션으로 근사할 때, 블록 내 연산은 정확하게, 블록 간 연산은 상태를 통해 근사한다. 이로써 O(n) 시간에 합리적인 품질을 달성한다.

  3. CUDA 그래프 제한: UNIFORM_SINGLE_TOKEN_DECODE만 지원한다. 선형 어텐션의 상태 업데이트가 시퀀스 길이에 따라 동적이므로, 가변 길이 배치에서의 CUDA 그래프 캡처가 어렵기 때문이다.

참고 자료

댓글

관련 포스트

vLLM 의 다른글