[SGLang] FlashAttention 백엔드: IO-aware 타일링 어텐션의 구현
들어가며
Standard attention은 O(N^2) 크기의 어텐션 행렬을 HBM(High Bandwidth Memory)에 materialize한다. 시퀀스 길이가 길어질수록 메모리 사용량과 HBM 접근 횟수가 급증하여 연산 병목이 발생한다. FlashAttention은 이 문제를 IO-aware 타일링으로 해결한다. 어텐션 행렬을 SRAM 크기에 맞게 타일로 분할하여 HBM 접근을 최소화하면서도 정확한 결과를 보장한다.
SGLang의 FlashAttentionBackend는 FlashAttention v3/v4 커널을 래핑하여 Paged KV Cache, Sliding Window, CUDA Graph, Speculative Decoding을 지원한다.
이 글에서는 python/sglang/srt/layers/attention/flashattention_backend.py를 분석한다.
Standard Attention vs FlashAttention
Before: Standard Attention
Q (N x d) K^T (d x N)
│ │
▼ ▼
┌──────────────────────┐
│ S = Q @ K^T │ ← O(N^2) HBM 쓰기
│ (N x N 행렬) │
└──────────┬───────────┘
▼
┌──────────────────────┐
│ P = softmax(S) │ ← O(N^2) HBM 읽기/쓰기
│ (N x N 행렬) │
└──────────┬───────────┘
▼
┌──────────────────────┐
│ O = P @ V │ ← O(N^2) HBM 읽기
└──────────────────────┘
총 HBM 접근: O(N^2 * d) 바이트
추가 메모리: O(N^2)
After: FlashAttention (IO-aware Tiling)
Q를 타일로 분할: Q_1, Q_2, ..., Q_T (각 B_r x d)
K, V를 타일로 분할: K_1, V_1, K_2, V_2, ... (각 B_c x d)
for each Q_i:
for each K_j, V_j:
┌─────────────────────────────────────┐
│ SRAM에서 처리 (on-chip) │
│ S_ij = Q_i @ K_j^T (B_r x B_c) │
│ P_ij = softmax(S_ij) │
│ O_i += P_ij @ V_j (누적) │
└─────────────────────────────────────┘
HBM에서 Q_i, K_j, V_j만 로드 → O_i만 저장
총 HBM 접근: O(N^2 * d^2 / M) 바이트 (M = SRAM 크기)
추가 메모리: O(N) — softmax 통계만 저장
FlashAttention의 핵심은 online softmax 알고리즘이다. 전체 어텐션 행렬을 materialize하지 않고, 타일 단위로 softmax의 running maximum과 sum을 유지하면서 정확한 결과를 구한다.
FlashAttentionMetadata: 레이어간 메타데이터 재사용
모든 레이어가 동일한 메타데이터를 공유하므로, 첫 번째 레이어에서 한 번만 계산한다.
@dataclass
class FlashAttentionMetadata:
cache_seqlens_int32: torch.Tensor = None
max_seq_len_q: int = 1
max_seq_len_k: int = 0
cu_seqlens_q: torch.Tensor = None
cu_seqlens_k: torch.Tensor = None
window_size: tuple = (-1, -1)
page_table: torch.Tensor = None
swa_page_table: torch.Tensor = None
cu_seqlens_q와 cu_seqlens_k는 cumulative sequence lengths로, variable-length 배치에서 각 시퀀스의 시작 위치를 나타낸다. page_table은 Paged KV Cache의 블록 인덱스 매핑이다.
FlashAttentionBackend 초기화
class FlashAttentionBackend(AttentionBackend):
def __init__(self, model_runner, skip_prefill=False,
speculative_step_id=0, topk=0,
speculative_num_steps=0, fa_impl_ver=3):
self.forward_metadata: FlashAttentionMetadata = None
self.max_context_len = model_runner.model_config.context_len
self.page_size = model_runner.page_size
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
# FA3 / FA4 버전 선택
if self.fa_impl_ver == 3:
from sgl_kernel.flash_attn import (
flash_attn_varlen_func,
flash_attn_with_kvcache,
)
elif self.fa_impl_ver == 4:
from sglang.jit_kernel.flash_attention_v4 import (
flash_attn_varlen_func,
flash_attn_with_kvcache,
)
FA3와 FA4는 동일한 백엔드 클래스에서 fa_impl_ver 파라미터로 구분된다. FA4는 SM90+(Hopper 이상)에서 더 최적화된 커널을 사용한다.
Deterministic Inference 지원
self.num_splits = (
1
if model_runner.server_args.enable_deterministic_inference
or (self.fa_impl_ver == 4
and not model_runner.server_args.disable_cuda_graph)
else 0
)
num_splits=1은 단일 split으로 실행하여 결정론적(deterministic) 결과를 보장한다. num_splits=0은 자동 heuristic으로 최적의 split 수를 결정한다. FA4는 CUDA Graph와 함께 사용할 때 num_splits=0을 지원하지 않으므로 1로 강제한다.
init_forward_metadata: 모드별 메타데이터 초기화
Decode 모드의 메타데이터 설정은 다음과 같다.
def init_forward_metadata(self, forward_batch: ForwardBatch):
metadata = FlashAttentionMetadata()
seqlens_in_batch = forward_batch.seq_lens
if forward_batch.forward_mode.is_decode_or_idle():
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
Decode에서 cu_seqlens_q는 단순히 0, 1, 2, ..., batch_size다. 각 요청이 정확히 1개의 새 토큰을 생성하기 때문이다. Extend 모드에서는 요청별로 다른 길이의 시퀀스를 처리하므로 cumulative sum으로 계산한다.
forward_extend: Prefill 경로
def forward_extend(self, q, k, v, layer, forward_batch,
save_kv_cache=True, q_rope=None, k_rope=None,
sinks=None):
if k is not None:
if save_kv_cache and not is_cp_mode:
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer, cache_loc, k, k_rope,
)
is_swa_layer = (
layer.sliding_window_size is not None and layer.sliding_window_size > -1
)
window_size = (layer.sliding_window_size, 0) if is_swa_layer else (-1, -1)
KV 캐시 저장 시 일반 MHA와 MLA를 구분한다. MLA는 K와 K_rope를 별도로 저장하며, V는 K에 포함된(absorbed) 형태로 관리된다. Sliding Window Attention(SWA) 레이어는 window_size 튜플로 윈도우 범위를 지정한다.
forward_decode: Decode 경로
Decode는 flash_attn_with_kvcache를 사용하여 Paged KV Cache에서 직접 어텐션을 계산한다.
# forward_decode에서 핵심 호출
o = self.flash_attn_with_kvcache(
q=q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k_cache=k_cache,
v_cache=v_cache,
cache_seqlens=metadata.cache_seqlens_int32,
page_table=page_table,
softmax_scale=layer.scaling,
causal=True,
window_size=window_size,
num_splits=self.num_splits,
)
page_table로 비연속 메모리 블록에 분산된 KV 캐시에 접근한다. 이것이 FlashAttention 커널이 Paged KV Cache를 지원하는 핵심 인터페이스다.
Local Attention 지원
Llama 4처럼 iRoPE를 사용하는 모델은 Local Attention이 필요하다.
@dataclass
class LocalAttentionMetadata:
local_query_start_loc: torch.Tensor = None
local_seqused_k: torch.Tensor = None
local_block_table: torch.Tensor = None
local_max_query_len: int = 0
local_max_seq_len: int = 0
Local Attention에서는 시퀀스를 attention_chunk_size 단위의 청크로 나누어 각 청크 내에서만 어텐션을 계산한다. 별도의 local_block_table과 local_seqused_k로 청크별 KV 캐시 범위를 관리한다.
성능 비교
| 항목 | Standard Attention | FlashAttention |
|---|---|---|
| HBM 접근 | O(N^2 * d) | O(N^2 * d^2 / M) |
| 추가 메모리 | O(N^2) | O(N) |
| N=4K, d=128 | ~4GB 어텐션 행렬 | 어텐션 행렬 없음 |
| N=128K, d=128 | ~128GB (불가능) | SRAM 타일링으로 처리 |
| Backward pass | O(N^2) 재계산 | Softmax 통계로 recompute |
M은 SRAM 크기(~192KB on H100)이며, d^2/M 팩터가 HBM 접근을 크게 줄인다.
관련 포스트
참고
관련 포스트
- [SGLang] Lightning Attention: 고속 선형 어텐션 구현
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [sglang] [SGLang] Blackwell(B200)에서 Diffusion Attention 성능을 7배 끌어올리는 Triton 커널 최적화 분석
- [sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
SGLang 의 다른글
- 이전글 [SGLang] Attention Registry: 동적 백엔드 선택 메커니즘
- 현재글 : [SGLang] FlashAttention 백엔드: IO-aware 타일링 어텐션의 구현
- 다음글 [SGLang] FlashInfer: 래그드 텐서 어텐션 엔진
댓글