본문으로 건너뛰기

[flashinfer] FlashInfer FP8 KV-Cache Prefill 성능 최적화: Repacking 기법을 통한 오버헤드 제거

PR 링크: flashinfer-ai/flashinfer#3485 상태: Merged | 변경: +328 / -63

들어가며

최근 LLM 추론에서 메모리 대역폭을 절약하기 위해 FP8 KV-cache를 도입하는 사례가 늘고 있습니다. 하지만 FlashInfer의 기존 FA2 BatchPrefill 커널에서는 FP8 데이터를 매번 레지스터 수준에서 BF16으로 dequantization하는 과정이 필요했습니다. 이는 연산 효율을 높이는 대신 ALU와 shuffle 오버헤드를 발생시켜, 결과적으로 BF16 대비 성능 저하를 초래했습니다. 본 PR은 이 문제를 해결하기 위해 FP8 데이터를 한 번만 BF16으로 변환하여 공유 메모리(Shared Memory)에 캐싱하는 'Repacking' 기법을 도입했습니다.

코드 분석

1. include/flashinfer/attention/prefill.cuh: Staging Buffer 도입

핵심 변경 사항은 FP8 데이터를 BF16으로 미리 변환해 두는 staging buffer를 공유 메모리에 할당하는 것입니다. std::conditional_t를 사용하여 필요한 경우에만 메모리를 할당하도록 설계되었습니다.

Before:

// 기존에는 별도의 staging buffer 없이 매번 dequantization 수행

After:

static constexpr bool USE_KV_REPACK =
    (sizeof(DTypeKV) == 1) && !is_fp4_type_v<DTypeKV> && (HEAD_DIM_VO != 64);
static constexpr uint32_t REPACK_BUF_ELEMS =
    CTA_TILE_KV * (HEAD_DIM_QK > HEAD_DIM_VO ? HEAD_DIM_QK : HEAD_DIM_VO);
alignas(16) std::conditional_t<USE_KV_REPACK, DTypeQ[REPACK_BUF_ELEMS], DTypeQ[1]> kv_smem_repack;

2. repack_fp8_tile_to_bf16 함수 추가

이 함수는 FP8 타일을 읽어와 BF16으로 변환한 뒤, ldmatrix가 효율적으로 읽을 수 있는 레이아웃으로 재배치합니다. shuffle-free한 vectorized pass를 통해 성능을 극대화했습니다.

template <typename KTraits, uint32_t HEAD_DIM>
__device__ __forceinline__ void repack_fp8_tile_to_bf16(...) {
  // ... vectorized cast ...
  vec_cast<DTypeQ, DTypeKV>::template cast<16>(conv, (DTypeKV*)&packed);
  dst[get_permuted_offset<SWIZZLE, BF16_COLS>(row, 2 * col)] = *(b128_t*)&conv[0];
  // ...
}

왜 이게 좋은가

이 최적화의 핵심은 '연산의 중복 제거'입니다. 기존에는 매 타일마다 dequantization을 수행했으나, 이제는 한 번의 패스로 BF16 staging buffer에 저장해두고, 이후 연산에서는 ldmatrix를 통해 shuffle 없이 데이터를 가져옵니다.

  • 성능 향상: RTX PRO 6000 Blackwell 기준, head_dim=128에서 약 1.1배에서 1.3배의 성능 향상을 보였습니다.
  • 메모리 효율: std::conditional_t를 사용하여 USE_KV_REPACK이 false인 경우(BF16, FP16 등)에는 메모리 오버헤드가 거의 발생하지 않도록 설계되었습니다.
  • 교훈: GPU 커널 최적화에서 레지스터 수준의 반복적인 데이터 변환은 병목이 되기 쉽습니다. 공유 메모리를 활용한 staging buffer 전략은 메모리 대역폭과 연산 유닛(Tensor Cores) 사이의 균형을 맞추는 데 매우 효과적입니다.

리뷰어 피드백 반영

리뷰 과정에서 conv 배열의 정렬(alignment) 문제가 지적되었습니다. b128_t로 재해석하여 공유 메모리에 저장할 때 16바이트 정렬이 보장되지 않으면 성능 저하나 오류가 발생할 수 있는데, 이를 alignas(16)을 추가하여 해결했습니다.

References

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글