[onnxruntime] ONNX Runtime: MoE Router GEMV 최적화 및 Bias Fusion 구현
PR 링크: microsoft/onnxruntime#29170 상태: Merged | 변경: +779 / -20
들어가며
최신 대규모 언어 모델(LLM) 아키텍처인 Mixture of Experts(MoE)는 추론 시 라우팅 연산이 빈번하게 발생합니다. 특히 GPT-OSS-20B와 같은 모델에서는 각 QMoE 노드 직전에 작은 규모의 라우터 프로젝션(MatMulNBits + Add)이 수행되는데, 이 작은 연산들이 전체 추론 파이프라인에서 오버헤드로 작용합니다. 본 PR은 이 라우터 연산을 위한 특화된 CUDA 커널 경로를 생성하고, Add 연산을 MatMulNBits 내부로 융합(Fusion)하여 커널 실행 횟수를 줄임으로써 성능을 최적화합니다.
코드 분석
1. CUDA 커널 특화 (matmul_4bits.cu)
기존의 범용 MatMul4Bits 커널은 다양한 형상을 처리하기 위해 복잡한 로직을 거칩니다. 이번 변경에서는 특정 형상(M=1, N=32, K=2880)에 대해 최적화된 MatMulFloatInt4RouterKernel을 도입했습니다.
Before (Generic Path):
// 기존 범용 커널은 공유 메모리 스테이징 및 범용적인 루프 구조를 가짐
__global__ void MatMulFloatInt4Kernel(...) {
// ... 복잡한 공유 메모리 로드 및 동기화 ...
}
After (Router Specialized Path):
// 라우터 전용 커널은 공유 메모리 없이 글로벌 메모리에서 직접 로드하여 오버헤드 제거
template <typename T, int BlockSize>
__global__ void MatMulFloatInt4RouterKernel(...) {
// ... Bias를 연산 중간에 즉시 더하여 Add 커널 호출을 생략 ...
if (threadIdx.x == 0) {
output[n_id] = result + bias[n_id];
}
}
2. 그래프 최적화 (matmul_nbits_fusion.cc)
MatMulNBits와 Add 노드가 연속될 때, 이를 하나의 MatMulNBits 노드로 합치는 그래프 변환기를 수정했습니다. 리뷰 과정에서 M=1인 경우에만 융합하도록 제약 조건을 강화하여 런타임 오류를 방지했습니다.
// M=1인 경우에만 Bias Fusion을 허용하도록 제약 조건 추가
if (A_shape.dim(-2) != 1) {
return false; // M > 1이면 융합하지 않음
}
왜 이게 좋은가
- 커널 실행 오버헤드 제거: 라우터 연산 후 발생하는
Add커널 호출을 제거했습니다. GPT-OSS-20B 모델에서 24개의Add노드를 제거하여 약+0.2%의 처리량 향상을 확인했습니다. - 메모리 접근 최적화: 라우터 특화 커널은 공유 메모리 스테이징을 생략하고 L2 캐시를 효율적으로 활용하도록 설계되어, 전체 라우터 GEMV 성능이 약
+1.6% ~ +1.8%개선되었습니다. - 안전한 최적화: 리뷰어들의 피드백을 반영하여
M > 1인 경우 융합을 방지하고,int64_t에서int로의 캐스팅 시 발생할 수 있는 오버플로우 문제를 방지하는 등 견고한 코드를 작성했습니다.
교훈
- 특화(Specialization)의 힘: 범용 커널은 유지보수에 좋지만, 모델의 특정 병목 지점(Hot path)에서는 고정된 형상에 최적화된 커널이 압도적인 성능을 냅니다.
- 그래프 융합 시 제약 조건: 그래프 변환기 작성 시, 런타임 커널이 지원하는 제약 조건(예: M=1)을 반드시 사전에 검증해야 런타임 실패를 방지할 수 있습니다.
리뷰어 피드백 반영
std::getenv대신 ORT 전용GetEnvironmentVar를 사용하여 플랫폼 간 일관성을 확보했습니다.LaunchMatMulNBitsBiasAdd에서 그리드 크기 오버플로우를 방지하기 위해 명시적인 제한 검사를 추가했습니다.IsSupportedRouterGemvShape를 GPT-OSS-20B 전용으로 제한하여 불필요한 테스트 범위를 줄이고 명확성을 높였습니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://onnxruntime.ai/docs/api/c/index.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [onnxruntime] ONNX Runtime QMoE SwiGLU GEMV 최적화: Split-K2 커널로 LLM 추론 가속화
- [onnxruntime] ONNX Runtime CUDA MoE: 소규모 배치 디코딩을 위한 SoftmaxTopK 라우터 최적화
- [flashinfer] FlashInfer의 MoE Routing 성능 최적화: Batcher's Odd-Even Merge Sort 도입
- [vllm] [vLLM] MiniMax-M2 MoE Gate 최적화: Fused FP32 Kernel로 서빙 성능 32% 향상시키기
- [flashinfer] FlashInfer의 TRTLLM-Gen MoE 라우팅 최적화: 레지스터 압박 해소와 성능 극대화
PR Analysis 의 다른글
- 이전글 [sglang] [HunyuanVideo] Sequence Parallelism 최적화: Text Token Sharding으로 성능 한계 돌파하기
- 현재글 : [onnxruntime] ONNX Runtime: MoE Router GEMV 최적화 및 Bias Fusion 구현
- 다음글 [axolotl] Axolotl, 대규모 언어 모델 학습 시 메모리 부족 문제 해결: 효율적인 데이터셋 처리 개선
댓글