[flashinfer] FlashInfer의 Per-token NVFP4 Quantization 커널 최적화 분석
PR 링크: flashinfer-ai/flashinfer#3237 상태: Merged | 변경: +0 / -0
들어가며
LLM 추론 성능을 극대화하기 위해 FP4와 같은 저정밀도 양자화 기법이 필수적으로 사용되고 있습니다. FlashInfer는 고성능 커널 라이브러리로서, 최근 도입된 per-token nvfp4 quantization 커널의 효율성을 높이기 위한 최적화 작업을 진행했습니다. 이번 PR은 커널의 블록 사이즈 조정, Fast Math 경로의 유연한 제어, 그리고 메모리 접근 패턴 최적화를 통해 전반적인 처리 성능을 개선하는 데 목적이 있습니다.
코드 분석
1. 커널 블록 사이즈 및 Fast Math 제어 (csrc/nv_internal/cpp/kernels/quantization.cu)
기존 커널은 고정된 블록 사이즈를 사용했으나, 이를 128로 최적화하고 환경 변수를 통해 Fast Math 경로를 제어할 수 있도록 변경했습니다.
Before:
constexpr uint32_t BLOCK_SIZE = 256;
// ...
DISPATCH_NVP4_QUANT_AND_PER_TOKEN_SCALE_KERNEL(LINEAR);
After:
constexpr uint32_t BLOCK_SIZE = 128;
// ...
bool disableFP4QuantFastMath = tensorrt_llm::common::getEnvDisableFP4QuantFastMath();
// ...
if (disableFP4QuantFastMath) {
DISPATCH_NVP4_QUANT_AND_PER_TOKEN_SCALE_KERNEL(LINEAR, true);
} else {
DISPATCH_NVP4_QUANT_AND_PER_TOKEN_SCALE_KERNEL(LINEAR, false);
}
2. 공유 메모리 캐싱 전략 변경 (csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh)
기존에는 local amax를 캐싱하려 했으나 성능 향상이 미미했습니다. 대신, 입력 데이터를 공유 메모리에 캐싱하여 글로벌 메모리 재접근 비용을 줄이는 방식으로 변경했습니다.
Before:
if constexpr (CACHE_LOCAL_AMAX) {
localAmaxSmem[vecIdx] = localAmax;
}
After:
if constexpr (CACHE_INPUT) {
#pragma unroll
for (uint32_t bank = 0; bank < BANKS_PER_THREAD; ++bank) {
constexpr uint32_t STRIDE = BANKS_PER_THREAD + 1;
auto reg = reinterpret_cast<uint32_t*>(&vec_in)[bank];
smem[vecIdx * STRIDE + bank] = reg;
}
}
왜 이게 좋은가
- 메모리 효율성:
CACHE_INPUT을 통해 글로벌 메모리 읽기 횟수를 최소화했습니다. LLM 커널에서 병목은 대개 연산보다 메모리 대역폭에 있으므로, 공유 메모리를 활용한 데이터 재사용은 성능에 직접적인 영향을 줍니다. - 유연성:
TRTLLM_DISABLE_FP4_QUANT_FAST_MATH환경 변수를 도입함으로써, 특정 하드웨어나 정밀도 요구사항에 따라 Fast Math의 사용 여부를 결정할 수 있게 되었습니다. 이는 TransformerEngine과의 호환성을 맞추는 데 중요합니다. - 블록 사이즈 최적화:
BLOCK_SIZE를 128로 조정함으로써 GPU의 워프 스케줄링 효율을 높이고 레지스터 압박을 완화했습니다.
리뷰어 피드백 반영
리뷰 과정에서 int64_t 사용에 대한 논의가 있었습니다. uint32_t가 32GB 이상의 텐서를 다룰 수 있음에도 불구하고, int64_t 사용 시의 오버헤드가 무시할 수 있는 수준임을 확인하여 코드 안정성을 유지했습니다. 또한, 테스트 환경에서 환경 변수를 동적으로 제어할 수 있도록 하여 테스트 커버리지를 확보했습니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] DeepSeek-V4를 위한 MXFP4 Marlin MoE 커널 최적화 및 JIT 통합 분석
- [flashinfer] FlashInfer, CUDA 그래프 호환성을 높이고 성능을 최적화하다: TRT-LLM FMHA v2 통합 및 불필요한 H2D 제거
- [flashinfer] FlashInfer: Wide Vector 최적화와 1900줄의 코드 삭제로 달성한 성능 개선
- [flashinfer] FlashInfer의 DiT 최적화: SageAttention과 Int8/FP8 혼합 정밀도 커널 도입 분석
- [sglang] SGLang의 성능 향상을 위한 기본 Quantization 커널 최적화: v2 도입
PR Analysis 의 다른글
- 이전글 [flashinfer] FlashInfer, MoE 및 FP8 GEMM 성능 향상을 위한 커널 업데이트
- 현재글 : [flashinfer] FlashInfer의 Per-token NVFP4 Quantization 커널 최적화 분석
- 다음글 [vllm] vLLM DeepSeek v4 Fused Indexer Q 양자화 커널 최적화: CuteDSL을 활용한 성능 향상
댓글