[vLLM] FlashAttention: IO-aware 타일링으로 어텐션 연산을 가속하는 원리
들어가며
표준 어텐션 연산은 Q·K^T 결과를 HBM(High Bandwidth Memory)에 한 번 쓰고, softmax를 위해 다시 읽고, V와 곱하기 위해 또 읽는다. 시퀀스 길이 N에 대해 O(N²) 크기의 중간 텐서가 HBM을 왕복하면서 IO가 병목이 된다. FlashAttention은 Q, K, V를 타일 단위로 SRAM에 올려 softmax까지 한 번에 처리함으로써 HBM 접근을 O(N²)에서 O(N)으로 줄인다.
- 논문: FlashAttention (arxiv 2205.14135), FlashAttention-2 (arxiv 2307.08691)
- 공식 문서: https://docs.vllm.ai
공식 문서
vLLM 공식 문서: Attention Backends
핵심 구조/코드 분석
FlashAttentionBackend: 백엔드 등록 구조
vLLM은 어텐션 백엔드를 플러그인 형태로 관리한다. vllm/v1/attention/backends/flash_attn.py에서 FlashAttention 백엔드가 정의된다:
class FlashAttentionBackend(AttentionBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16, torch.bfloat16
]
@staticmethod
def get_kv_cache_shape(
num_blocks, block_size, num_kv_heads, head_size, ...
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
KV 캐시의 shape이 (2, num_blocks, block_size, num_kv_heads, head_size)임을 알 수 있다. 2는 K와 V를 구분하는 차원이고, PagedAttention의 블록 구조와 결합된다.
FlashAttentionImpl: 실제 forward 연산
class FlashAttentionImpl(AttentionImpl):
def __init__(self, num_heads, head_size, scale, num_kv_heads, ...):
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=alibi_slopes is not None,
head_size=head_size,
)
num_queries_per_kv는 GQA(Grouped Query Attention) 비율이다. FlashAttention은 GQA를 네이티브로 지원하여 KV 헤드를 여러 Q 헤드가 공유할 때도 효율적으로 동작한다.
FlashAttentionMetadata: 가변 길이 배치 처리
@dataclass
class FlashAttentionMetadata:
num_actual_tokens: int
max_query_len: int
query_start_loc: torch.Tensor # 각 요청의 쿼리 시작 위치
max_seq_len: int
seq_lens: torch.Tensor # 각 요청의 시퀀스 길이
block_table: torch.Tensor # PagedAttention 블록 테이블
slot_mapping: torch.Tensor
# Cascade Attention 관련
use_cascade: bool
common_prefix_len: int
query_start_loc와 seq_lens가 핵심이다. flash_attn_varlen_func에 이 정보를 넘겨서 하나의 커널 호출로 다양한 길이의 요청을 한 번에 처리한다.
FA3의 AOT 스케줄링
FlashAttention 3에서는 커널 실행 전에 워크로드 분배를 미리 계산한다:
class FlashAttentionMetadataBuilder:
_cudagraph_support = (
AttentionCGSupport.ALWAYS
if get_flash_attn_version() == 3
else AttentionCGSupport.UNIFORM_BATCH
)
def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
self.aot_schedule = get_flash_attn_version() == 3
if self.use_full_cuda_graph and self.aot_schedule:
# FA3 scheduler_metadata를 CUDA graph에 고정 할당
max_batch_size = max(
vllm_config.scheduler_config.max_num_seqs,
self.max_cudagraph_size or 0,
)
self.scheduler_metadata = torch.zeros(
1 + round_up(max_batch_size, 4) * 4,
dtype=torch.int32,
)
FA3는 CUDA graph를 모든 경우에 지원(ALWAYS)하지만, FA2는 UNIFORM_BATCH에서만 지원한다. 이것은 FA2의 max_query_len=1 최적화가 mixed prefill-decode와 호환되지 않기 때문이다.
Cascade Attention: 공통 프리픽스 최적화
# common prefix가 있을 때
prefix_scheduler_metadata = schedule(
batch_size=1,
cu_query_lens=cu_prefix_query_lens,
max_query_len=num_actual_tokens,
seqlens=prefix_kv_lens,
max_seq_len=common_prefix_len,
causal=False,
)
scheduler_metadata = schedule(
batch_size=num_reqs,
seqlens=suffix_kv_lens,
max_seq_len=max_seq_len - common_prefix_len,
causal=True,
)
배치 내 모든 요청이 공유하는 프리픽스가 있으면, 프리픽스 어텐션을 한 번만 계산하고 각 요청의 suffix만 개별 처리한다.
왜 이 설계인가
-
IO-awareness: GPU의 SRAM은 HBM보다 10배 이상 빠르다. FlashAttention은 타일 단위로 SRAM에서 softmax까지 완료하여 HBM 왕복을 최소화한다.
-
PagedAttention과의 시너지:
block_table과slot_mapping을 통해 비연속 KV 캐시에 직접 접근한다. 메모리 효율과 연산 효율을 동시에 달성한다. -
CUDA Graph 호환성: FA3부터는 모든 배치 구성에서 CUDA graph를 지원하여 커널 런칭 오버헤드를 제거한다. 이것은 작은 배치에서 특히 중요하다.
-
정확한 어텐션: FlashAttention은 근사가 아니라 수학적으로 정확한 어텐션이다. online softmax 알고리즘을 사용하여 타일 단위 계산에서도 정확한 결과를 보장한다.
FlashAttention은 "알고리즘을 하드웨어에 맞추는" 시스템-알고리즘 공동 설계의 대표적 사례이다.
논문 핵심 내용
FlashAttention 논문의 핵심은 어텐션 연산을 IO-aware하게 재설계하여, HBM 접근 횟수를 O(N^2)에서 O(N)으로 줄인 것이다. 표준 어텐션은 N x N 크기의 중간 행렬(QK^T 결과)을 HBM에 쓰고 다시 읽는데, FlashAttention은 타일링과 online softmax 알고리즘을 결합하여 이 중간 행렬을 아예 HBM에 쓰지 않는다. SRAM에서 타일 단위로 softmax를 점진적으로 계산하고, 최종 결과만 HBM에 쓴다.
벤치마크 결과는 다음과 같다:
| 벤치마크 | 성능 |
|---|---|
| BERT-large (seq 512) 학습 | MLPerf 1.1 기록 대비 15% 벽시계 속도 향상 |
| GPT-2 (seq 1K) 학습 | 표준 어텐션 대비 3배 속도 향상 |
| Long-range Arena (seq 1K-4K) | 2.4배 속도 향상 |
| GPT-2 perplexity | 0.7 포인트 개선 (더 빠른 학습 = 더 많은 반복) |
| Long-document 분류 | 6.4 포인트 정확도 향상 |
| Path-X (seq 16K) | 61.4% 정확도 (Transformer가 처음으로 random 이상 달성) |
| Path-256 (seq 64K) | 63.1% 정확도 |
FlashAttention은 단순한 속도 향상을 넘어, 더 긴 시퀀스를 실용적으로 처리할 수 있게 해준다. 메모리 사용량이 시퀀스 길이에 선형으로만 증가하기 때문에, 기존에 OOM으로 불가능했던 16K, 64K 시퀀스 학습이 가능해졌다. Path-X와 Path-256에서 처음으로 random chance를 넘긴 것이 이를 증명한다. 수학적으로 정확한 어텐션이면서도 속도와 메모리를 동시에 개선한다는 점이 FlashAttention의 가장 강력한 특성이다.
관련 포스트
vLLM 의 다른글
- 이전글 [vLLM] Chunked Prefill: 긴 프롬프트를 청크 단위로 분할 처리하는 기법
- 현재글 : [vLLM] FlashAttention: IO-aware 타일링으로 어텐션 연산을 가속하는 원리
- 다음글 [vLLM] FlashInfer: LLM 서빙에 특화된 어텐션 엔진
댓글