본문으로 건너뛰기

[flashinfer] FlashInfer의 DiT 최적화: SageAttention과 Int8/FP8 혼합 정밀도 커널 도입 분석

PR 링크: flashinfer-ai/flashinfer#2711 상태: Merged | 변경: +None / -None

들어가며

최근 Diffusion Transformer(DiT) 모델의 크기가 커짐에 따라, 추론 성능을 극대화하기 위한 양자화 및 커널 최적화의 중요성이 커지고 있습니다. 특히 DiT는 기존 LLM과는 다른 어텐션 패턴과 정밀도 요구사항을 가집니다.

이번 FlashInfer PR에서는 DiT 모델에 최적화된 TRTLLM(TensorRT-LLM) 기반 커널들을 대거 추가했습니다. 핵심은 SageAttention 기법을 지원하기 위한 스케일링 팩터(Scaling Factors)의 도입과, Q/K/V 각각에 대해 서로 다른 데이터 타입(Int8, BFloat16, E4M3 등)을 유연하게 적용할 수 있는 혼합 정밀도(Mixed-Precision) 지원입니다.

코드 분석: 무엇이 바뀌었는가?

1. 캐시 키(Cache Key)의 확장 및 세밀한 커널 매칭

기본적으로 FlashInfer는 런타임에 적절한 커널을 선택하기 위해 TllmGenFmhaRunnerCache를 사용합니다. 기존에는 Q, KV, O의 데이터 타입만 고려했지만, 이제는 SageAttention 관련 파라미터들까지 캐시 키에 포함됩니다.

Before:

using Key = std::tuple<Data_type, Data_type, Data_type>;
// ...
Key key = std::make_tuple(q_data_type, kv_data_type, o_data_type);

After:

using Key = std::tuple<Data_type, Data_type, Data_type, Data_type, int, int, int, int>;
// ...
Key key = std::make_tuple(q_data_type, k_data_type, v_data_type, o_data_type, 
                          num_elts_sage_q, num_elts_sage_k, num_elts_sage_p, num_elts_sage_v);

이 변경을 통해 K와 V의 데이터 타입을 분리해서 관리할 수 있게 되었으며, SageAttention 블록별 스케일링에 필요한 요소 수(num_elts_sage_*)를 커널 선택의 기준으로 삼게 되었습니다.

2. SageAttention 스케일링 팩터 전달

SageAttention은 텐서 전체가 아닌 블록 단위로 스케일링을 수행하여 정밀도 손실을 최소화합니다. 이를 위해 trtllm_ragged_attention_launcher 함수에 관련 포인터들이 대거 추가되었습니다.

After (csrc/trtllm_fmha_kernel_launcher.cu):

// SageAttention scaling factors 전달
runner_params.ptrSageAttnSfsQ = sage_attn_sfs_q;
runner_params.ptrSageAttnSfsK = sage_attn_sfs_k;
runner_params.ptrSageAttnSfsP = sage_attn_sfs_p;
runner_params.ptrSageAttnSfsV = sage_attn_sfs_v;

이 포인터들은 커널 내부에서 Int8이나 FP8로 양자화된 행렬 곱셈(BMM1, BMM2)을 수행한 후, 다시 원래의 스케일로 복원하거나 소프트맥스 연산에 반영할 때 사용됩니다.

3. 데이터 타입 재해석(Reinterpretation) 전략

리뷰어 saltyminty와 작성자 xrq-phys 간의 논의에서 흥미로운 지점이 발견됩니다. 특정 커널에서는 하드웨어 가속기(TMA, Tensor Memory Accelerator)가 데이터를 E4M3로 읽어오지만, 실제 연산(Tensor Core)에서는 Int8이나 BFloat16으로 처리하도록 설계되었습니다.

xrq-phys: "TMA loads input 'as if E4m3' but TC treats them as Int8 / Bfloat16."

이는 NVIDIA Hopper 아키텍처의 특성을 활용하여 메모리 대역폭 효율과 연산 정밀도 사이의 균형을 맞추려는 고도의 최적화 기법입니다.

왜 이게 좋은 최적화인가?

  1. 메모리 대역폭 절감: V(Value) 텐서를 E4M3(FP8)로 유지함으로써 메모리 읽기 비용을 절반으로 줄이면서도, Q/K 연산은 Int8이나 BFloat16으로 수행하여 DiT 모델의 품질 저하를 막습니다.
  2. SageAttention 통합: 블록 단위 스케일링을 지원함으로써, 단순한 Per-tensor 양자화보다 훨씬 높은 정확도를 보장합니다. 이는 특히 이상치(Outlier)가 많은 어텐션 맵을 가진 모델에서 유리합니다.
  3. 유연한 API: dtypeKdtypeV를 분리함으로써, 향후 더 다양한 혼합 정밀도 조합(예: K는 FP8, V는 Int8 등)에 즉각적으로 대응할 수 있는 구조를 갖추었습니다.

결론

이번 PR은 FlashInfer가 단순한 LLM 가속기를 넘어, DiT와 같은 차세대 아키텍처와 SageAttention 같은 최신 양자화 기법을 수용하는 방향으로 진화하고 있음을 보여줍니다. 시니어 엔지니어로서 주목할 점은, 단순히 기능을 추가하는 것에 그치지 않고 TMA와 Tensor Core의 하드웨어적 특성을 고려한 데이터 재해석까지 설계에 녹여냈다는 점입니다.

이러한 저수준 최적화는 고성능 AI 추론 엔진을 개발할 때 반드시 참고해야 할 모범 사례입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글