[vLLM] FlashInfer: LLM 서빙에 특화된 어텐션 엔진
들어가며
FlashAttention이 일반 어텐션 연산의 IO 최적화에 초점을 맞추었다면, FlashInfer는 LLM 서빙 시나리오에 특화된 어텐션 엔진이다. Prefill과 Decode를 별도 커널로 처리하고, FP8/FP4 KV 캐시를 네이티브로 지원하며, TRT-LLM 백엔드와의 통합까지 제공한다.
- 논문: FlashInfer: Efficient and Customizable Attention Engine for LLM Inference (arxiv 2501.01005)
- 공식 문서: https://docs.vllm.ai
공식 문서
vLLM 공식 문서: Attention Backends
핵심 구조/코드 분석
FlashInfer 백엔드 구조
vllm/v1/attention/backends/flashinfer.py에서 FlashInfer 백엔드가 정의된다:
class FlashInferBackend(AttentionBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16, torch.bfloat16
]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "float16", "bfloat16",
"fp8", "fp8_e4m3", "fp8_e5m2",
]
@staticmethod
def get_supported_kernel_block_sizes():
return [16, 32, 64] # Blackwell에서는 16, 32, 64만 지원
FlashAttention과 비교해서 눈에 띄는 차이점은 fp8_e5m2까지 지원한다는 점이다. FP8 KV 캐시는 메모리를 절반으로 줄여서 동일한 GPU에서 더 긴 시퀀스를 처리할 수 있게 한다.
Prefill/Decode 분리 처리
FlashInfer의 핵심은 prefill과 decode를 완전히 별도의 커널로 처리한다는 점이다:
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper,
)
from flashinfer.decode import (
fast_decode_plan,
trtllm_batch_decode_with_kv_cache,
)
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
BatchDecodeWithPagedKVCacheWrapper: 디코딩 전용. 각 요청이 1개 토큰만 생성하는 패턴에 최적화되어 있다.BatchPrefillWithPagedKVCacheWrapper: Paged KV 캐시에 대한 prefill. 이미 캐시에 있는 context를 읽는 데 사용된다.BatchPrefillWithRaggedKVCacheWrapper: Ragged(비정형) KV에 대한 prefill. 새로 계산되는 토큰의 self-attention에 사용된다.
TRT-LLM 백엔드 통합
FlashInfer는 TensorRT-LLM의 어텐션 커널도 래핑한다:
trtllm_gen_workspace_buffer = None
def _get_trtllm_gen_workspace_buffer():
global trtllm_gen_workspace_buffer
if trtllm_gen_workspace_buffer is None:
trtllm_gen_workspace_buffer = torch.zeros(
envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8, device="cuda"
)
return trtllm_gen_workspace_buffer
VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE 환경변수로 워크스페이스 크기를 조절할 수 있다.
FP8 KV 캐시 디퀀타이제이션
FP8 KV 캐시를 prefill에서 사용할 때는 Triton 커널로 실시간 디퀀타이제이션을 수행한다:
@triton.jit
def _trtllm_prefill_attn_kvfp8_dequant(
kv_cache_ptr,
block_tables_prefill_ptr,
...
NUM_KV_HEADS: tl.constexpr,
):
batch_idx = tl.program_id(0).to(tl.int64)
mock_block_table_idx = tl.program_id(1).to(tl.int64)
orig_page_num = tl.load(
block_tables_prefill_ptr + batch_idx * block_table_stride
+ mock_block_table_idx
).to(tl.int64)
# FP8 → FP16/BF16 변환
fp8_k = tl.load(kv_cache_ptr + src_k)
dequant_k = (fp8_k.to(tl.float32) * k_scale_val).to(dequant_dtype)
tl.store(mock_kv_cache_ptr + dst_k, dequant_k)
블록 테이블을 따라가면서 FP8 값을 읽고, scale을 곱해서 원래 정밀도로 복원한다. mock KV 캐시에 연속적으로 재배치하여 prefill 커널이 효율적으로 접근할 수 있게 한다.
DCP (Decode Context Parallelism) 지원
class BatchDCPPrefillWrapper:
def __init__(self, workspace_buffer, dcp_a2a=False):
if dcp_a2a:
self._dcp_combine = partial(
dcp_a2a_lse_reduce, is_lse_base_on_e=False
)
else:
self._dcp_combine = partial(
cp_lse_ag_out_rs, is_lse_base_on_e=False
)
self._context = BatchPrefillWithPagedKVCacheWrapper(...)
self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper(...)
긴 시퀀스의 KV 캐시를 여러 GPU에 분산(Context Parallelism)할 때, FlashInfer의 prefill 래퍼를 활용하여 분산 어텐션을 수행한다.
왜 이 설계인가
-
Decode 최적화: LLM 서빙에서 대부분의 시간은 디코딩에 소요된다. FlashInfer는 decode 전용 커널로 이 경로를 극한까지 최적화한다.
-
양자화 KV 캐시 네이티브 지원: FP8, FP4 KV 캐시를 커널 레벨에서 직접 지원하여 메모리 절감과 성능을 동시에 달성한다.
-
유연한 블록 크기: 16, 32, 64 블록 크기를 지원하여 모델과 하드웨어에 맞는 최적 설정을 선택할 수 있다.
-
다중 백엔드 통합: FlashInfer 자체 커널과 TRT-LLM 커널을 상황에 따라 선택하여 최적의 성능을 달성한다.
FlashInfer는 "서빙에서 가장 중요한 것은 decode 레이턴시"라는 실용적 관점에서 설계된 엔진이다.
관련 포스트
vLLM 의 다른글
- 이전글 [vLLM] FlashAttention: IO-aware 타일링으로 어텐션 연산을 가속하는 원리
- 현재글 : [vLLM] FlashInfer: LLM 서빙에 특화된 어텐션 엔진
- 다음글 [vLLM] Multi-head Latent Attention: KV 캐시를 압축하는 DeepSeek의 어텐션
댓글