본문으로 건너뛰기

[flashinfer] FlashInfer의 MoE Routing 성능 최적화: Batcher's Odd-Even Merge Sort 도입

PR 링크: flashinfer-ai/flashinfer#3476 상태: Merged | 변경: +108 / -26

들어가며

최근 대규모 언어 모델(LLM)에서 Mixture-of-Experts(MoE) 아키텍처가 널리 사용됨에 따라, 효율적인 라우팅(Routing) 연산의 중요성이 커지고 있습니다. 특히 MoE의 Top-K 선택 과정은 전체 추론 성능에 영향을 미치는 핵심 병목 구간 중 하나입니다. 이번 FlashInfer의 PR은 SM100+ 아키텍처를 타겟으로 한 64비트 리덕션 최적화와, 기존의 비효율적인 Odd-Even Transposition Sort를 Batcher's Odd-Even Merge Network로 교체하여 성능을 크게 향상시켰습니다.

코드 분석

1. 64비트 리덕션 최적화 (csrc/fused_moe/moeTopKFuncs.cuh)

기존에는 64비트 타입(float/int)에 대해 cg::reduce를 사용했으나, 새로운 구현에서는 SM100+ 아키텍처의 redux.sync.max.u32 명령어를 직접 활용합니다. 64비트 값을 상위 32비트(값)와 하위 32비트(인덱스)로 분리하여 두 번의 리덕션을 수행함으로써 성능을 최적화했습니다.

// Before
return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{});

// After
uint32_t hi = static_cast<uint32_t>(compValIdx >> 32);
uint32_t lo = static_cast<uint32_t>(compValIdx & 0xffffffffu);
asm volatile("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(maxHi) : "r"(hi));
uint32_t loContrib = (hi == maxHi) ? lo : 0u;
asm volatile("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(maxLo) : "r"(loContrib));
return (static_cast<TypeCmp>(maxHi) << 32) | static_cast<TypeCmp>(maxLo);

2. Batcher's Odd-Even Merge Sort 도입 (include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh)

기존의 Odd-Even Transposition Sort는 $O(N^2)$ 복잡도를 가지며 비-2의 거듭제곱(non-power-of-two) 크기에서 비효율적이었습니다. 이를 컴파일 타임에 결정되는 Batcher's 네트워크로 교체하여 정렬 성능을 개선했습니다.

// Before: O(N^2) 루프 기반 정렬
for (int pass = 0; pass < N; ++pass) { ... }

// After: 컴파일 타임 템플릿 재귀를 통한 Batcher's 네트워크
constexpr int P = NextPow2<N>::value;
topkSortBatcher<0, P, N, RedType>(topK);

왜 이게 좋은가

이번 최적화의 핵심은 하드웨어 가속기 활용컴파일 타임 최적화입니다.

  1. 성능 향상: K=22, N=2048, T=4096 환경에서 기존 대비 약 2.125배의 속도 향상을 기록했습니다. 이는 정렬 알고리즘의 복잡도를 낮추고 SASS(Shader Assembly) 수준에서 최적화된 명령어를 사용했기 때문입니다.
  2. 유연성: NextPow2 템플릿을 사용하여 임의의 K값에 대해서도 효율적인 정렬 네트워크를 생성합니다. 이는 런타임 패딩 없이도 정렬이 가능하게 하여 메모리 접근을 최소화합니다.
  3. 교훈: GPU 연산 최적화 시, 범용 라이브러리 함수(cg::reduce)에 의존하기보다 하드웨어 특화 명령어(asm volatile)를 적절히 섞어 쓰는 것이 특정 아키텍처에서 큰 성능 이득을 가져올 수 있음을 보여줍니다.

결론

이번 PR은 MoE 라우팅의 미세한 병목을 효과적으로 제거했습니다. 비록 E2E 성능에는 큰 차이가 없을지라도, 이러한 마이크로 최적화들이 모여 전체 시스템의 처리량을 극대화하는 밑거름이 됩니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글