[SGLang] FlashInfer: 래그드 텐서 어텐션 엔진
들어가며
FlashInfer는 LLM serving에 특화된 어텐션 커널 라이브러리다. FlashAttention과 동일한 IO-aware 원리를 사용하되, serving 환경에서 빈번한 가변 길이 배치(ragged batch)를 일급 시민(first-class citizen)으로 다룬다. SGLang에서 FlashInfer는 기본(default) 어텐션 백엔드로, Prefill의 Ragged KV Cache와 Decode의 Paged KV Cache를 모두 지원한다.
이 글에서는 python/sglang/srt/layers/attention/flashinfer_backend.py를 분석한다.
전체 구조
FlashInfer 백엔드는 세 종류의 wrapper를 조합하여 모든 forward 모드를 처리한다.
┌──────────────────────────────────────────────────────┐
│ FlashInferAttnBackend │
│ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ Prefill (Extend) │ │
│ │ │ │
│ │ ┌─ use_ragged=True ─────────────────────────┐ │ │
│ │ │ prefill_wrapper_ragged │ │ │
│ │ │ (BatchPrefillWithRaggedKVCacheWrapper) │ │ │
│ │ │ → 새 토큰의 Q,K,V로 직접 어텐션 │ │ │
│ │ │ + prefill_wrappers_paged (prefix cache) │ │ │
│ │ │ → merge_state()로 결과 합산 │ │ │
│ │ └────────────────────────────────────────────┘ │ │
│ │ │ │
│ │ ┌─ use_ragged=False ────────────────────────┐ │ │
│ │ │ prefill_wrappers_paged │ │ │
│ │ │ (BatchPrefillWithPagedKVCacheWrapper) │ │ │
│ │ │ → KV cache에서 전체 어텐션 │ │ │
│ │ └────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ Decode │ │
│ │ decode_wrappers │ │
│ │ (BatchDecodeWithPagedKVCacheWrapper) │ │
│ │ → Paged KV Cache에서 1-token 어텐션 │ │
│ └─────────────────────────────────────────────────┘ │
└──────────────────────────────────────────────────────┘
FlashInfer의 Ragged Tensor 처리
Ragged tensor는 각 행의 길이가 다른 2D 텐서다. LLM serving에서 배치 내 시퀀스들은 거의 항상 다른 길이를 가진다. FlashInfer는 이를 패딩 없이 처리한다.
Padded Batch (기존 방식): Ragged Batch (FlashInfer):
┌──────────────────────┐ ┌──────────────┐
│ seq1 ████████░░░░░░░░│ │ seq1 ████████│
│ seq2 ████░░░░░░░░░░░░│ │ seq2 ████ │
│ seq3 ██████████████░░│ │ seq3 ██████████████│
└──────────────────────┘ └──────────────┘
max_len으로 패딩 → 메모리 낭비 indptr로 경계 표시 → 무낭비
초기화: Wrapper 구성
class FlashInferAttnBackend(AttentionBackend):
def __init__(self, model_runner, skip_prefill=False,
kv_indptr_buf=None, kv_last_page_len_buf=None,
init_new_workspace=False):
self.prefill_backend = "fa2"
self.decode_backend = "fa2"
self.decode_use_tensor_cores = should_use_tensor_core(
kv_cache_dtype=model_runner.kv_cache_dtype,
num_attention_heads=...,
num_kv_heads=...,
)
should_use_tensor_core는 GQA(Grouped Query Attention)에서 Q head 수와 KV head 수의 비율에 따라 tensor core 사용 여부를 결정한다. GQA 비율이 높으면 tensor core가 더 효율적이다.
Sliding Window / Cross Attention 이중 래퍼
if model_runner.sliding_window_size is not None:
self.num_wrappers = 2
self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
elif model_runner.model_config.is_encoder_decoder:
self.num_wrappers = 2
self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
else:
self.num_wrappers = 1
self.dispatch_reason = None
Sliding Window 모델(Mistral)이나 encoder-decoder 모델(Whisper)은 두 종류의 어텐션이 공존한다. 이를 위해 두 개의 wrapper를 생성하고, 레이어별로 적절한 wrapper를 선택한다.
def _get_wrapper_idx(self, layer: RadixAttention):
if self.num_wrappers == 1:
return 0
if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
return layer.sliding_window_size == -1 # SWA가 아닌 레이어 → 1
if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
return layer.is_cross_attention # cross attention 레이어 → 1
init_forward_metadata: Plan 단계
FlashInfer는 "plan-then-execute" 패턴을 사용한다. init_forward_metadata에서 인덱스 계산(plan)을 수행하고, forward_extend/forward_decode에서 실제 어텐션을 실행(execute)한다.
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_cpu,
forward_batch.seq_lens_sum,
decode_wrappers=self.decode_wrappers,
)
self.forward_metadata = DecodeMetadata(self.decode_wrappers)
indices_updater_decode.update는 req_to_token 테이블에서 KV 인덱스(kv_indices)를 생성하고, decode_wrappers의 begin_forward(plan)를 호출한다. 이 plan 단계에서 GPU 커널의 block 할당과 메모리 접근 패턴이 결정된다.
forward_extend: Ragged + Paged 이중 어텐션
Extend(prefill) 모드에서 FlashInfer는 두 가지 경로를 지원한다.
def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
if not self.forward_metadata.use_ragged:
# 모든 KV를 캐시에 저장한 뒤, paged wrapper로 어텐션
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
o = prefill_wrapper_paged.forward(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=causal, sm_scale=layer.scaling,
)
else:
if self.forward_metadata.extend_no_prefix:
# prefix 없음: ragged wrapper만으로 처리
o = self.prefill_wrapper_ragged.forward(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim),
v.view(-1, layer.tp_v_head_num, layer.head_dim),
causal=True, sm_scale=layer.scaling,
)
else:
# prefix 있음: ragged(새 토큰) + paged(캐시) → merge
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(...)
o2, s2 = prefill_wrapper_paged.forward_return_lse(...)
o, _ = merge_state(o1, s1, o2, s2)
extend_no_prefix인 경우, 새로 입력된 토큰의 Q, K, V로 직접 ragged 어텐션을 수행한다. 이것이 가장 빠른 경로다.
Prefix cache가 존재하면 cascade attention을 사용한다. 새 토큰 간 ragged 어텐션(o1, s1)과 캐시된 prefix에 대한 paged 어텐션(o2, s2)을 각각 계산한 뒤, merge_state로 softmax 통계를 이용해 정확하게 합산한다.
forward_decode: Paged KV Cache 어텐션
def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
decode_wrapper = self.forward_metadata.decode_wrappers[
self._get_wrapper_idx(layer)
]
if k is not None:
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
k_scale=layer.k_scale_float,
v_scale=layer.v_scale_float,
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
Decode에서는 BatchDecodeWithPagedKVCacheWrapper가 paged KV cache에서 각 요청에 대해 1-token 어텐션을 수행한다. k_scale_float과 v_scale_float은 FP8 KV cache의 디퀀타이제이션 스케일이다. _float 접미사는 CUDA Graph 호환성을 위해 device-to-host 복사를 방지한다.
Multi-Item Scoring 지원
FlashInfer 백엔드는 Multi-Item Scoring이라는 특수 어텐션 패턴을 지원한다. 이는 여러 아이템이 delimiter로 구분된 시퀀스에서, 각 아이템 경계를 존중하는 어텐션 마스크를 적용한다.
@dataclass
class MultiItemScoringParams:
prefix_len_ptr: Optional[torch.Tensor] = None
token_pos_in_items_ptr: Optional[torch.Tensor] = None
token_pos_in_items_len: int = 0
max_item_len_ptr: Optional[torch.Tensor] = None
예를 들어, "질문 <delim> 답변A <delim> 답변B <delim>"에서 각 답변은 독립적으로 질문에만 어텐션하고, 다른 답변에는 어텐션하지 않는다.
FlashInfer vs FlashAttention 비교
| 항목 | FlashAttention | FlashInfer |
|---|---|---|
| 주요 타겟 | 학습 + 추론 | 추론 serving 특화 |
| Ragged batch | varlen API로 지원 |
네이티브 지원 |
| Paged KV cache | page_table 파라미터 |
전용 wrapper |
| Plan-Execute | 없음 (즉시 실행) | 2단계 (인덱스 precompute) |
| Cascade attention | 미지원 | merge_state 기반 |
| CUDA Graph | 외부에서 관리 | Wrapper 내장 지원 |
| Multi-item scoring | 미지원 | 네이티브 파라미터 |
관련 포스트
참고
관련 포스트
- [flashinfer] FlashInfer FP8 KV-Cache Prefill 성능 최적화: Repacking 기법을 통한 오버헤드 제거
- [논문리뷰] LVSA: Training-Free Sparse Attention for Long Video Diffusion
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [flashinfer] FlashInfer MLA 커널 최적화: num_heads < 128 환경에서의 성능 극대화
SGLang 의 다른글
- 이전글 [SGLang] FlashAttention 백엔드: IO-aware 타일링 어텐션의 구현
- 현재글 : [SGLang] FlashInfer: 래그드 텐서 어텐션 엔진
- 다음글 [SGLang] Multi-head Latent Attention (MLA): KV 캐시 압축 어텐션
댓글