본문으로 건너뛰기

[vLLM] FlashInfer: LLM 서빙에 특화된 어텐션 엔진

들어가며

FlashAttention이 일반 어텐션 연산의 IO 최적화에 초점을 맞추었다면, FlashInfer는 LLM 서빙 시나리오에 특화된 어텐션 엔진이다. Prefill과 Decode를 별도 커널로 처리하고, FP8/FP4 KV 캐시를 네이티브로 지원하며, TRT-LLM 백엔드와의 통합까지 제공한다.

공식 문서

vLLM 공식 문서: Attention Backends

핵심 구조/코드 분석

FlashInfer 백엔드 구조

vllm/v1/attention/backends/flashinfer.py에서 FlashInfer 백엔드가 정의된다:

class FlashInferBackend(AttentionBackend):
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16, torch.bfloat16
    ]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto", "float16", "bfloat16",
        "fp8", "fp8_e4m3", "fp8_e5m2",
    ]

    @staticmethod
    def get_supported_kernel_block_sizes():
        return [16, 32, 64]  # Blackwell에서는 16, 32, 64만 지원

FlashAttention과 비교해서 눈에 띄는 차이점은 fp8_e5m2까지 지원한다는 점이다. FP8 KV 캐시는 메모리를 절반으로 줄여서 동일한 GPU에서 더 긴 시퀀스를 처리할 수 있게 한다.

Prefill/Decode 분리 처리

FlashInfer의 핵심은 prefill과 decode를 완전히 별도의 커널로 처리한다는 점이다:

from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
    MultiLevelCascadeAttentionWrapper,
)
from flashinfer.decode import (
    fast_decode_plan,
    trtllm_batch_decode_with_kv_cache,
)
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
  • BatchDecodeWithPagedKVCacheWrapper: 디코딩 전용. 각 요청이 1개 토큰만 생성하는 패턴에 최적화되어 있다.
  • BatchPrefillWithPagedKVCacheWrapper: Paged KV 캐시에 대한 prefill. 이미 캐시에 있는 context를 읽는 데 사용된다.
  • BatchPrefillWithRaggedKVCacheWrapper: Ragged(비정형) KV에 대한 prefill. 새로 계산되는 토큰의 self-attention에 사용된다.

TRT-LLM 백엔드 통합

FlashInfer는 TensorRT-LLM의 어텐션 커널도 래핑한다:

trtllm_gen_workspace_buffer = None

def _get_trtllm_gen_workspace_buffer():
    global trtllm_gen_workspace_buffer
    if trtllm_gen_workspace_buffer is None:
        trtllm_gen_workspace_buffer = torch.zeros(
            envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
            dtype=torch.uint8, device="cuda"
        )
    return trtllm_gen_workspace_buffer

VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE 환경변수로 워크스페이스 크기를 조절할 수 있다.

FP8 KV 캐시 디퀀타이제이션

FP8 KV 캐시를 prefill에서 사용할 때는 Triton 커널로 실시간 디퀀타이제이션을 수행한다:

@triton.jit
def _trtllm_prefill_attn_kvfp8_dequant(
    kv_cache_ptr,
    block_tables_prefill_ptr,
    ...
    NUM_KV_HEADS: tl.constexpr,
):
    batch_idx = tl.program_id(0).to(tl.int64)
    mock_block_table_idx = tl.program_id(1).to(tl.int64)
    orig_page_num = tl.load(
        block_tables_prefill_ptr + batch_idx * block_table_stride
        + mock_block_table_idx
    ).to(tl.int64)
    # FP8 → FP16/BF16 변환
    fp8_k = tl.load(kv_cache_ptr + src_k)
    dequant_k = (fp8_k.to(tl.float32) * k_scale_val).to(dequant_dtype)
    tl.store(mock_kv_cache_ptr + dst_k, dequant_k)

블록 테이블을 따라가면서 FP8 값을 읽고, scale을 곱해서 원래 정밀도로 복원한다. mock KV 캐시에 연속적으로 재배치하여 prefill 커널이 효율적으로 접근할 수 있게 한다.

DCP (Decode Context Parallelism) 지원

class BatchDCPPrefillWrapper:
    def __init__(self, workspace_buffer, dcp_a2a=False):
        if dcp_a2a:
            self._dcp_combine = partial(
                dcp_a2a_lse_reduce, is_lse_base_on_e=False
            )
        else:
            self._dcp_combine = partial(
                cp_lse_ag_out_rs, is_lse_base_on_e=False
            )
        self._context = BatchPrefillWithPagedKVCacheWrapper(...)
        self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper(...)

긴 시퀀스의 KV 캐시를 여러 GPU에 분산(Context Parallelism)할 때, FlashInfer의 prefill 래퍼를 활용하여 분산 어텐션을 수행한다.

왜 이 설계인가

  1. Decode 최적화: LLM 서빙에서 대부분의 시간은 디코딩에 소요된다. FlashInfer는 decode 전용 커널로 이 경로를 극한까지 최적화한다.

  2. 양자화 KV 캐시 네이티브 지원: FP8, FP4 KV 캐시를 커널 레벨에서 직접 지원하여 메모리 절감과 성능을 동시에 달성한다.

  3. 유연한 블록 크기: 16, 32, 64 블록 크기를 지원하여 모델과 하드웨어에 맞는 최적 설정을 선택할 수 있다.

  4. 다중 백엔드 통합: FlashInfer 자체 커널과 TRT-LLM 커널을 상황에 따라 선택하여 최적의 성능을 달성한다.

FlashInfer는 "서빙에서 가장 중요한 것은 decode 레이턴시"라는 실용적 관점에서 설계된 엔진이다.

댓글

관련 포스트

vLLM 의 다른글