[flashinfer] FlashInfer의 MoE Routing 성능 최적화: Batcher's Odd-Even Merge Sort 도입
PR 링크: flashinfer-ai/flashinfer#3476 상태: Merged | 변경: +108 / -26
들어가며
최근 대규모 언어 모델(LLM)에서 Mixture-of-Experts(MoE) 아키텍처가 널리 사용됨에 따라, 효율적인 라우팅(Routing) 연산의 중요성이 커지고 있습니다. 특히 MoE의 Top-K 선택 과정은 전체 추론 성능에 영향을 미치는 핵심 병목 구간 중 하나입니다. 이번 FlashInfer의 PR은 SM100+ 아키텍처를 타겟으로 한 64비트 리덕션 최적화와, 기존의 비효율적인 Odd-Even Transposition Sort를 Batcher's Odd-Even Merge Network로 교체하여 성능을 크게 향상시켰습니다.
코드 분석
1. 64비트 리덕션 최적화 (csrc/fused_moe/moeTopKFuncs.cuh)
기존에는 64비트 타입(float/int)에 대해 cg::reduce를 사용했으나, 새로운 구현에서는 SM100+ 아키텍처의 redux.sync.max.u32 명령어를 직접 활용합니다. 64비트 값을 상위 32비트(값)와 하위 32비트(인덱스)로 분리하여 두 번의 리덕션을 수행함으로써 성능을 최적화했습니다.
// Before
return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{});
// After
uint32_t hi = static_cast<uint32_t>(compValIdx >> 32);
uint32_t lo = static_cast<uint32_t>(compValIdx & 0xffffffffu);
asm volatile("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(maxHi) : "r"(hi));
uint32_t loContrib = (hi == maxHi) ? lo : 0u;
asm volatile("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(maxLo) : "r"(loContrib));
return (static_cast<TypeCmp>(maxHi) << 32) | static_cast<TypeCmp>(maxLo);
2. Batcher's Odd-Even Merge Sort 도입 (include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh)
기존의 Odd-Even Transposition Sort는 $O(N^2)$ 복잡도를 가지며 비-2의 거듭제곱(non-power-of-two) 크기에서 비효율적이었습니다. 이를 컴파일 타임에 결정되는 Batcher's 네트워크로 교체하여 정렬 성능을 개선했습니다.
// Before: O(N^2) 루프 기반 정렬
for (int pass = 0; pass < N; ++pass) { ... }
// After: 컴파일 타임 템플릿 재귀를 통한 Batcher's 네트워크
constexpr int P = NextPow2<N>::value;
topkSortBatcher<0, P, N, RedType>(topK);
왜 이게 좋은가
이번 최적화의 핵심은 하드웨어 가속기 활용과 컴파일 타임 최적화입니다.
- 성능 향상: K=22, N=2048, T=4096 환경에서 기존 대비 약 2.125배의 속도 향상을 기록했습니다. 이는 정렬 알고리즘의 복잡도를 낮추고 SASS(Shader Assembly) 수준에서 최적화된 명령어를 사용했기 때문입니다.
- 유연성:
NextPow2템플릿을 사용하여 임의의 K값에 대해서도 효율적인 정렬 네트워크를 생성합니다. 이는 런타임 패딩 없이도 정렬이 가능하게 하여 메모리 접근을 최소화합니다. - 교훈: GPU 연산 최적화 시, 범용 라이브러리 함수(
cg::reduce)에 의존하기보다 하드웨어 특화 명령어(asm volatile)를 적절히 섞어 쓰는 것이 특정 아키텍처에서 큰 성능 이득을 가져올 수 있음을 보여줍니다.
결론
이번 PR은 MoE 라우팅의 미세한 병목을 효과적으로 제거했습니다. 비록 E2E 성능에는 큰 차이가 없을지라도, 이러한 마이크로 최적화들이 모여 전체 시스템의 처리량을 극대화하는 밑거름이 됩니다.
참고 자료
- https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-redux
- https://developer.nvidia.com/blog/efficient-sorting-networks-on-gpus/
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [flashinfer] FlashInfer MLA 커널 최적화: num_heads < 128 환경에서의 성능 극대화
- [flashinfer] FlashInfer FP8 KV-Cache Prefill 성능 최적화: Repacking 기법을 통한 오버헤드 제거
- [vllm] [vLLM] MiniMax-M2 MoE Gate 최적화: Fused FP32 Kernel로 서빙 성능 32% 향상시키기
- [triton] [Triton] Persistent Matmul 성능을 13% 향상시킨 정교한 Shared Memory 계산 기법 분석
- [flashinfer] FlashInfer의 DeepSeek V4 Sparse MLA 최적화 분석
PR Analysis 의 다른글
- 이전글 [ray] Ray Data의 hash_partition 성능을 7배 향상시킨 최적화 전략
- 현재글 : [flashinfer] FlashInfer의 MoE Routing 성능 최적화: Batcher's Odd-Even Merge Sort 도입
- 다음글 [sglang] 실시간 RGB 전송 속도 향상을 위한 최적화 분석
댓글