본문으로 건너뛰기

[vLLM] FlashAttention: IO-aware 타일링으로 어텐션 연산을 가속하는 원리

들어가며

표준 어텐션 연산은 Q·K^T 결과를 HBM(High Bandwidth Memory)에 한 번 쓰고, softmax를 위해 다시 읽고, V와 곱하기 위해 또 읽는다. 시퀀스 길이 N에 대해 O(N²) 크기의 중간 텐서가 HBM을 왕복하면서 IO가 병목이 된다. FlashAttention은 Q, K, V를 타일 단위로 SRAM에 올려 softmax까지 한 번에 처리함으로써 HBM 접근을 O(N²)에서 O(N)으로 줄인다.

공식 문서

vLLM 공식 문서: Attention Backends

핵심 구조/코드 분석

FlashAttentionBackend: 백엔드 등록 구조

vLLM은 어텐션 백엔드를 플러그인 형태로 관리한다. vllm/v1/attention/backends/flash_attn.py에서 FlashAttention 백엔드가 정의된다:

class FlashAttentionBackend(AttentionBackend):
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16, torch.bfloat16
    ]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks, block_size, num_kv_heads, head_size, ...
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

KV 캐시의 shape이 (2, num_blocks, block_size, num_kv_heads, head_size)임을 알 수 있다. 2는 K와 V를 구분하는 차원이고, PagedAttention의 블록 구조와 결합된다.

FlashAttentionImpl: 실제 forward 연산

class FlashAttentionImpl(AttentionImpl):
    def __init__(self, num_heads, head_size, scale, num_kv_heads, ...):
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.vllm_flash_attn_version = get_flash_attn_version(
            requires_alibi=alibi_slopes is not None,
            head_size=head_size,
        )

num_queries_per_kv는 GQA(Grouped Query Attention) 비율이다. FlashAttention은 GQA를 네이티브로 지원하여 KV 헤드를 여러 Q 헤드가 공유할 때도 효율적으로 동작한다.

FlashAttentionMetadata: 가변 길이 배치 처리

@dataclass
class FlashAttentionMetadata:
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor   # 각 요청의 쿼리 시작 위치
    max_seq_len: int
    seq_lens: torch.Tensor           # 각 요청의 시퀀스 길이
    block_table: torch.Tensor        # PagedAttention 블록 테이블
    slot_mapping: torch.Tensor

    # Cascade Attention 관련
    use_cascade: bool
    common_prefix_len: int

query_start_locseq_lens가 핵심이다. flash_attn_varlen_func에 이 정보를 넘겨서 하나의 커널 호출로 다양한 길이의 요청을 한 번에 처리한다.

FA3의 AOT 스케줄링

FlashAttention 3에서는 커널 실행 전에 워크로드 분배를 미리 계산한다:

class FlashAttentionMetadataBuilder:
    _cudagraph_support = (
        AttentionCGSupport.ALWAYS
        if get_flash_attn_version() == 3
        else AttentionCGSupport.UNIFORM_BATCH
    )

    def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
        self.aot_schedule = get_flash_attn_version() == 3
        if self.use_full_cuda_graph and self.aot_schedule:
            # FA3 scheduler_metadata를 CUDA graph에 고정 할당
            max_batch_size = max(
                vllm_config.scheduler_config.max_num_seqs,
                self.max_cudagraph_size or 0,
            )
            self.scheduler_metadata = torch.zeros(
                1 + round_up(max_batch_size, 4) * 4,
                dtype=torch.int32,
            )

FA3는 CUDA graph를 모든 경우에 지원(ALWAYS)하지만, FA2는 UNIFORM_BATCH에서만 지원한다. 이것은 FA2의 max_query_len=1 최적화가 mixed prefill-decode와 호환되지 않기 때문이다.

Cascade Attention: 공통 프리픽스 최적화

# common prefix가 있을 때
prefix_scheduler_metadata = schedule(
    batch_size=1,
    cu_query_lens=cu_prefix_query_lens,
    max_query_len=num_actual_tokens,
    seqlens=prefix_kv_lens,
    max_seq_len=common_prefix_len,
    causal=False,
)
scheduler_metadata = schedule(
    batch_size=num_reqs,
    seqlens=suffix_kv_lens,
    max_seq_len=max_seq_len - common_prefix_len,
    causal=True,
)

배치 내 모든 요청이 공유하는 프리픽스가 있으면, 프리픽스 어텐션을 한 번만 계산하고 각 요청의 suffix만 개별 처리한다.

왜 이 설계인가

  1. IO-awareness: GPU의 SRAM은 HBM보다 10배 이상 빠르다. FlashAttention은 타일 단위로 SRAM에서 softmax까지 완료하여 HBM 왕복을 최소화한다.

  2. PagedAttention과의 시너지: block_tableslot_mapping을 통해 비연속 KV 캐시에 직접 접근한다. 메모리 효율과 연산 효율을 동시에 달성한다.

  3. CUDA Graph 호환성: FA3부터는 모든 배치 구성에서 CUDA graph를 지원하여 커널 런칭 오버헤드를 제거한다. 이것은 작은 배치에서 특히 중요하다.

  4. 정확한 어텐션: FlashAttention은 근사가 아니라 수학적으로 정확한 어텐션이다. online softmax 알고리즘을 사용하여 타일 단위 계산에서도 정확한 결과를 보장한다.

FlashAttention은 "알고리즘을 하드웨어에 맞추는" 시스템-알고리즘 공동 설계의 대표적 사례이다.

논문 핵심 내용

FlashAttention 논문의 핵심은 어텐션 연산을 IO-aware하게 재설계하여, HBM 접근 횟수를 O(N^2)에서 O(N)으로 줄인 것이다. 표준 어텐션은 N x N 크기의 중간 행렬(QK^T 결과)을 HBM에 쓰고 다시 읽는데, FlashAttention은 타일링과 online softmax 알고리즘을 결합하여 이 중간 행렬을 아예 HBM에 쓰지 않는다. SRAM에서 타일 단위로 softmax를 점진적으로 계산하고, 최종 결과만 HBM에 쓴다.

벤치마크 결과는 다음과 같다:

벤치마크 성능
BERT-large (seq 512) 학습 MLPerf 1.1 기록 대비 15% 벽시계 속도 향상
GPT-2 (seq 1K) 학습 표준 어텐션 대비 3배 속도 향상
Long-range Arena (seq 1K-4K) 2.4배 속도 향상
GPT-2 perplexity 0.7 포인트 개선 (더 빠른 학습 = 더 많은 반복)
Long-document 분류 6.4 포인트 정확도 향상
Path-X (seq 16K) 61.4% 정확도 (Transformer가 처음으로 random 이상 달성)
Path-256 (seq 64K) 63.1% 정확도

FlashAttention은 단순한 속도 향상을 넘어, 더 긴 시퀀스를 실용적으로 처리할 수 있게 해준다. 메모리 사용량이 시퀀스 길이에 선형으로만 증가하기 때문에, 기존에 OOM으로 불가능했던 16K, 64K 시퀀스 학습이 가능해졌다. Path-X와 Path-256에서 처음으로 random chance를 넘긴 것이 이를 증명한다. 수학적으로 정확한 어텐션이면서도 속도와 메모리를 동시에 개선한다는 점이 FlashAttention의 가장 강력한 특성이다.

댓글

관련 포스트

vLLM 의 다른글