본문으로 건너뛰기

[SGLang] FlashInfer: 래그드 텐서 어텐션 엔진

들어가며

FlashInfer는 LLM serving에 특화된 어텐션 커널 라이브러리다. FlashAttention과 동일한 IO-aware 원리를 사용하되, serving 환경에서 빈번한 가변 길이 배치(ragged batch)를 일급 시민(first-class citizen)으로 다룬다. SGLang에서 FlashInfer는 기본(default) 어텐션 백엔드로, Prefill의 Ragged KV Cache와 Decode의 Paged KV Cache를 모두 지원한다.

이 글에서는 python/sglang/srt/layers/attention/flashinfer_backend.py를 분석한다.

전체 구조

FlashInfer 백엔드는 세 종류의 wrapper를 조합하여 모든 forward 모드를 처리한다.

┌──────────────────────────────────────────────────────┐
│              FlashInferAttnBackend                    │
│                                                      │
│  ┌─────────────────────────────────────────────────┐ │
│  │  Prefill (Extend)                               │ │
│  │                                                 │ │
│  │  ┌─ use_ragged=True ─────────────────────────┐  │ │
│  │  │  prefill_wrapper_ragged                    │  │ │
│  │  │  (BatchPrefillWithRaggedKVCacheWrapper)    │  │ │
│  │  │  → 새 토큰의 Q,K,V로 직접 어텐션            │  │ │
│  │  │  + prefill_wrappers_paged (prefix cache)  │  │ │
│  │  │  → merge_state()로 결과 합산               │  │ │
│  │  └────────────────────────────────────────────┘  │ │
│  │                                                 │ │
│  │  ┌─ use_ragged=False ────────────────────────┐  │ │
│  │  │  prefill_wrappers_paged                   │  │ │
│  │  │  (BatchPrefillWithPagedKVCacheWrapper)     │  │ │
│  │  │  → KV cache에서 전체 어텐션                 │  │ │
│  │  └────────────────────────────────────────────┘  │ │
│  └─────────────────────────────────────────────────┘ │
│                                                      │
│  ┌─────────────────────────────────────────────────┐ │
│  │  Decode                                         │ │
│  │  decode_wrappers                                │ │
│  │  (BatchDecodeWithPagedKVCacheWrapper)            │ │
│  │  → Paged KV Cache에서 1-token 어텐션             │ │
│  └─────────────────────────────────────────────────┘ │
└──────────────────────────────────────────────────────┘

FlashInfer의 Ragged Tensor 처리

Ragged tensor는 각 행의 길이가 다른 2D 텐서다. LLM serving에서 배치 내 시퀀스들은 거의 항상 다른 길이를 가진다. FlashInfer는 이를 패딩 없이 처리한다.

  Padded Batch (기존 방식):          Ragged Batch (FlashInfer):
  ┌──────────────────────┐          ┌──────────────┐
  │ seq1 ████████░░░░░░░░│          │ seq1 ████████│
  │ seq2 ████░░░░░░░░░░░░│          │ seq2 ████    │
  │ seq3 ██████████████░░│          │ seq3 ██████████████│
  └──────────────────────┘          └──────────────┘
  max_len으로 패딩 → 메모리 낭비       indptr로 경계 표시 → 무낭비

초기화: Wrapper 구성

class FlashInferAttnBackend(AttentionBackend):
    def __init__(self, model_runner, skip_prefill=False,
                 kv_indptr_buf=None, kv_last_page_len_buf=None,
                 init_new_workspace=False):
        self.prefill_backend = "fa2"
        self.decode_backend = "fa2"
        self.decode_use_tensor_cores = should_use_tensor_core(
            kv_cache_dtype=model_runner.kv_cache_dtype,
            num_attention_heads=...,
            num_kv_heads=...,
        )

should_use_tensor_core는 GQA(Grouped Query Attention)에서 Q head 수와 KV head 수의 비율에 따라 tensor core 사용 여부를 결정한다. GQA 비율이 높으면 tensor core가 더 효율적이다.

Sliding Window / Cross Attention 이중 래퍼

if model_runner.sliding_window_size is not None:
    self.num_wrappers = 2
    self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
elif model_runner.model_config.is_encoder_decoder:
    self.num_wrappers = 2
    self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
else:
    self.num_wrappers = 1
    self.dispatch_reason = None

Sliding Window 모델(Mistral)이나 encoder-decoder 모델(Whisper)은 두 종류의 어텐션이 공존한다. 이를 위해 두 개의 wrapper를 생성하고, 레이어별로 적절한 wrapper를 선택한다.

def _get_wrapper_idx(self, layer: RadixAttention):
    if self.num_wrappers == 1:
        return 0
    if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
        return layer.sliding_window_size == -1  # SWA가 아닌 레이어 → 1
    if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
        return layer.is_cross_attention  # cross attention 레이어 → 1

init_forward_metadata: Plan 단계

FlashInfer는 "plan-then-execute" 패턴을 사용한다. init_forward_metadata에서 인덱스 계산(plan)을 수행하고, forward_extend/forward_decode에서 실제 어텐션을 실행(execute)한다.

def init_forward_metadata(self, forward_batch: ForwardBatch):
    if forward_batch.forward_mode.is_decode_or_idle():
        self.indices_updater_decode.update(
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.seq_lens_cpu,
            forward_batch.seq_lens_sum,
            decode_wrappers=self.decode_wrappers,
        )
        self.forward_metadata = DecodeMetadata(self.decode_wrappers)

indices_updater_decode.updatereq_to_token 테이블에서 KV 인덱스(kv_indices)를 생성하고, decode_wrappersbegin_forward(plan)를 호출한다. 이 plan 단계에서 GPU 커널의 block 할당과 메모리 접근 패턴이 결정된다.

forward_extend: Ragged + Paged 이중 어텐션

Extend(prefill) 모드에서 FlashInfer는 두 가지 경로를 지원한다.

def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
    if not self.forward_metadata.use_ragged:
        # 모든 KV를 캐시에 저장한 뒤, paged wrapper로 어텐션
        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
            )
        o = prefill_wrapper_paged.forward(
            q.view(-1, layer.tp_q_head_num, layer.head_dim),
            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
            causal=causal, sm_scale=layer.scaling,
        )
    else:
        if self.forward_metadata.extend_no_prefix:
            # prefix 없음: ragged wrapper만으로 처리
            o = self.prefill_wrapper_ragged.forward(
                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                k.view(-1, layer.tp_k_head_num, layer.head_dim),
                v.view(-1, layer.tp_v_head_num, layer.head_dim),
                causal=True, sm_scale=layer.scaling,
            )
        else:
            # prefix 있음: ragged(새 토큰) + paged(캐시) → merge
            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(...)
            o2, s2 = prefill_wrapper_paged.forward_return_lse(...)
            o, _ = merge_state(o1, s1, o2, s2)

extend_no_prefix인 경우, 새로 입력된 토큰의 Q, K, V로 직접 ragged 어텐션을 수행한다. 이것이 가장 빠른 경로다.

Prefix cache가 존재하면 cascade attention을 사용한다. 새 토큰 간 ragged 어텐션(o1, s1)과 캐시된 prefix에 대한 paged 어텐션(o2, s2)을 각각 계산한 뒤, merge_state로 softmax 통계를 이용해 정확하게 합산한다.

forward_decode: Paged KV Cache 어텐션

def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
    decode_wrapper = self.forward_metadata.decode_wrappers[
        self._get_wrapper_idx(layer)
    ]
    if k is not None:
        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
            )
    o = decode_wrapper.forward(
        q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
        forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
        sm_scale=layer.scaling,
        logits_soft_cap=layer.logit_cap,
        k_scale=layer.k_scale_float,
        v_scale=layer.v_scale_float,
    )
    return o.view(-1, layer.tp_q_head_num * layer.head_dim)

Decode에서는 BatchDecodeWithPagedKVCacheWrapper가 paged KV cache에서 각 요청에 대해 1-token 어텐션을 수행한다. k_scale_floatv_scale_float은 FP8 KV cache의 디퀀타이제이션 스케일이다. _float 접미사는 CUDA Graph 호환성을 위해 device-to-host 복사를 방지한다.

Multi-Item Scoring 지원

FlashInfer 백엔드는 Multi-Item Scoring이라는 특수 어텐션 패턴을 지원한다. 이는 여러 아이템이 delimiter로 구분된 시퀀스에서, 각 아이템 경계를 존중하는 어텐션 마스크를 적용한다.

@dataclass
class MultiItemScoringParams:
    prefix_len_ptr: Optional[torch.Tensor] = None
    token_pos_in_items_ptr: Optional[torch.Tensor] = None
    token_pos_in_items_len: int = 0
    max_item_len_ptr: Optional[torch.Tensor] = None

예를 들어, "질문 <delim> 답변A <delim> 답변B <delim>"에서 각 답변은 독립적으로 질문에만 어텐션하고, 다른 답변에는 어텐션하지 않는다.

FlashInfer vs FlashAttention 비교

항목 FlashAttention FlashInfer
주요 타겟 학습 + 추론 추론 serving 특화
Ragged batch varlen API로 지원 네이티브 지원
Paged KV cache page_table 파라미터 전용 wrapper
Plan-Execute 없음 (즉시 실행) 2단계 (인덱스 precompute)
Cascade attention 미지원 merge_state 기반
CUDA Graph 외부에서 관리 Wrapper 내장 지원
Multi-item scoring 미지원 네이티브 파라미터

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글