본문으로 건너뛰기

[SGLang] FlashAttention 백엔드: IO-aware 타일링 어텐션의 구현

들어가며

Standard attention은 O(N^2) 크기의 어텐션 행렬을 HBM(High Bandwidth Memory)에 materialize한다. 시퀀스 길이가 길어질수록 메모리 사용량과 HBM 접근 횟수가 급증하여 연산 병목이 발생한다. FlashAttention은 이 문제를 IO-aware 타일링으로 해결한다. 어텐션 행렬을 SRAM 크기에 맞게 타일로 분할하여 HBM 접근을 최소화하면서도 정확한 결과를 보장한다.

SGLang의 FlashAttentionBackend는 FlashAttention v3/v4 커널을 래핑하여 Paged KV Cache, Sliding Window, CUDA Graph, Speculative Decoding을 지원한다.

이 글에서는 python/sglang/srt/layers/attention/flashattention_backend.py를 분석한다.

Standard Attention vs FlashAttention

Before: Standard Attention

  Q (N x d)     K^T (d x N)
     │              │
     ▼              ▼
  ┌──────────────────────┐
  │  S = Q @ K^T         │  ← O(N^2) HBM 쓰기
  │  (N x N 행렬)         │
  └──────────┬───────────┘
             ▼
  ┌──────────────────────┐
  │  P = softmax(S)      │  ← O(N^2) HBM 읽기/쓰기
  │  (N x N 행렬)         │
  └──────────┬───────────┘
             ▼
  ┌──────────────────────┐
  │  O = P @ V           │  ← O(N^2) HBM 읽기
  └──────────────────────┘

  총 HBM 접근: O(N^2 * d) 바이트
  추가 메모리: O(N^2)

After: FlashAttention (IO-aware Tiling)

  Q를 타일로 분할: Q_1, Q_2, ..., Q_T (각 B_r x d)
  K, V를 타일로 분할: K_1, V_1, K_2, V_2, ...  (각 B_c x d)

  for each Q_i:
    for each K_j, V_j:
      ┌─────────────────────────────────────┐
      │  SRAM에서 처리 (on-chip)              │
      │  S_ij = Q_i @ K_j^T   (B_r x B_c)  │
      │  P_ij = softmax(S_ij)              │
      │  O_i += P_ij @ V_j    (누적)        │
      └─────────────────────────────────────┘
      HBM에서 Q_i, K_j, V_j만 로드 → O_i만 저장

  총 HBM 접근: O(N^2 * d^2 / M) 바이트  (M = SRAM 크기)
  추가 메모리: O(N) — softmax 통계만 저장

FlashAttention의 핵심은 online softmax 알고리즘이다. 전체 어텐션 행렬을 materialize하지 않고, 타일 단위로 softmax의 running maximum과 sum을 유지하면서 정확한 결과를 구한다.

FlashAttentionMetadata: 레이어간 메타데이터 재사용

모든 레이어가 동일한 메타데이터를 공유하므로, 첫 번째 레이어에서 한 번만 계산한다.

@dataclass
class FlashAttentionMetadata:
    cache_seqlens_int32: torch.Tensor = None
    max_seq_len_q: int = 1
    max_seq_len_k: int = 0
    cu_seqlens_q: torch.Tensor = None
    cu_seqlens_k: torch.Tensor = None
    window_size: tuple = (-1, -1)
    page_table: torch.Tensor = None
    swa_page_table: torch.Tensor = None

cu_seqlens_qcu_seqlens_k는 cumulative sequence lengths로, variable-length 배치에서 각 시퀀스의 시작 위치를 나타낸다. page_table은 Paged KV Cache의 블록 인덱스 매핑이다.

FlashAttentionBackend 초기화

class FlashAttentionBackend(AttentionBackend):
    def __init__(self, model_runner, skip_prefill=False,
                 speculative_step_id=0, topk=0,
                 speculative_num_steps=0, fa_impl_ver=3):
        self.forward_metadata: FlashAttentionMetadata = None
        self.max_context_len = model_runner.model_config.context_len
        self.page_size = model_runner.page_size
        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA

        # FA3 / FA4 버전 선택
        if self.fa_impl_ver == 3:
            from sgl_kernel.flash_attn import (
                flash_attn_varlen_func,
                flash_attn_with_kvcache,
            )
        elif self.fa_impl_ver == 4:
            from sglang.jit_kernel.flash_attention_v4 import (
                flash_attn_varlen_func,
                flash_attn_with_kvcache,
            )

FA3와 FA4는 동일한 백엔드 클래스에서 fa_impl_ver 파라미터로 구분된다. FA4는 SM90+(Hopper 이상)에서 더 최적화된 커널을 사용한다.

Deterministic Inference 지원

self.num_splits = (
    1
    if model_runner.server_args.enable_deterministic_inference
    or (self.fa_impl_ver == 4
        and not model_runner.server_args.disable_cuda_graph)
    else 0
)

num_splits=1은 단일 split으로 실행하여 결정론적(deterministic) 결과를 보장한다. num_splits=0은 자동 heuristic으로 최적의 split 수를 결정한다. FA4는 CUDA Graph와 함께 사용할 때 num_splits=0을 지원하지 않으므로 1로 강제한다.

init_forward_metadata: 모드별 메타데이터 초기화

Decode 모드의 메타데이터 설정은 다음과 같다.

def init_forward_metadata(self, forward_batch: ForwardBatch):
    metadata = FlashAttentionMetadata()
    seqlens_in_batch = forward_batch.seq_lens

    if forward_batch.forward_mode.is_decode_or_idle():
        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
        metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
        metadata.cu_seqlens_q = torch.arange(
            0, batch_size + 1, dtype=torch.int32, device=device
        )
        metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
            forward_batch.req_pool_indices, : metadata.max_seq_len_k
        ]

Decode에서 cu_seqlens_q는 단순히 0, 1, 2, ..., batch_size다. 각 요청이 정확히 1개의 새 토큰을 생성하기 때문이다. Extend 모드에서는 요청별로 다른 길이의 시퀀스를 처리하므로 cumulative sum으로 계산한다.

forward_extend: Prefill 경로

def forward_extend(self, q, k, v, layer, forward_batch,
                   save_kv_cache=True, q_rope=None, k_rope=None,
                   sinks=None):
    if k is not None:
        if save_kv_cache and not is_cp_mode:
            if not self.use_mla:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                )
            else:
                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                    layer, cache_loc, k, k_rope,
                )

    is_swa_layer = (
        layer.sliding_window_size is not None and layer.sliding_window_size > -1
    )
    window_size = (layer.sliding_window_size, 0) if is_swa_layer else (-1, -1)

KV 캐시 저장 시 일반 MHA와 MLA를 구분한다. MLA는 K와 K_rope를 별도로 저장하며, V는 K에 포함된(absorbed) 형태로 관리된다. Sliding Window Attention(SWA) 레이어는 window_size 튜플로 윈도우 범위를 지정한다.

forward_decode: Decode 경로

Decode는 flash_attn_with_kvcache를 사용하여 Paged KV Cache에서 직접 어텐션을 계산한다.

# forward_decode에서 핵심 호출
o = self.flash_attn_with_kvcache(
    q=q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
    k_cache=k_cache,
    v_cache=v_cache,
    cache_seqlens=metadata.cache_seqlens_int32,
    page_table=page_table,
    softmax_scale=layer.scaling,
    causal=True,
    window_size=window_size,
    num_splits=self.num_splits,
)

page_table로 비연속 메모리 블록에 분산된 KV 캐시에 접근한다. 이것이 FlashAttention 커널이 Paged KV Cache를 지원하는 핵심 인터페이스다.

Local Attention 지원

Llama 4처럼 iRoPE를 사용하는 모델은 Local Attention이 필요하다.

@dataclass
class LocalAttentionMetadata:
    local_query_start_loc: torch.Tensor = None
    local_seqused_k: torch.Tensor = None
    local_block_table: torch.Tensor = None
    local_max_query_len: int = 0
    local_max_seq_len: int = 0

Local Attention에서는 시퀀스를 attention_chunk_size 단위의 청크로 나누어 각 청크 내에서만 어텐션을 계산한다. 별도의 local_block_tablelocal_seqused_k로 청크별 KV 캐시 범위를 관리한다.

성능 비교

항목 Standard Attention FlashAttention
HBM 접근 O(N^2 * d) O(N^2 * d^2 / M)
추가 메모리 O(N^2) O(N)
N=4K, d=128 ~4GB 어텐션 행렬 어텐션 행렬 없음
N=128K, d=128 ~128GB (불가능) SRAM 타일링으로 처리
Backward pass O(N^2) 재계산 Softmax 통계로 recompute

M은 SRAM 크기(~192KB on H100)이며, d^2/M 팩터가 HBM 접근을 크게 줄인다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글