본문으로 건너뛰기

[onnxruntime] ONNX Runtime: MoE Router GEMV 최적화 및 Bias Fusion 구현

PR 링크: microsoft/onnxruntime#29170 상태: Merged | 변경: +779 / -20

들어가며

최신 대규모 언어 모델(LLM) 아키텍처인 Mixture of Experts(MoE)는 추론 시 라우팅 연산이 빈번하게 발생합니다. 특히 GPT-OSS-20B와 같은 모델에서는 각 QMoE 노드 직전에 작은 규모의 라우터 프로젝션(MatMulNBits + Add)이 수행되는데, 이 작은 연산들이 전체 추론 파이프라인에서 오버헤드로 작용합니다. 본 PR은 이 라우터 연산을 위한 특화된 CUDA 커널 경로를 생성하고, Add 연산을 MatMulNBits 내부로 융합(Fusion)하여 커널 실행 횟수를 줄임으로써 성능을 최적화합니다.

코드 분석

1. CUDA 커널 특화 (matmul_4bits.cu)

기존의 범용 MatMul4Bits 커널은 다양한 형상을 처리하기 위해 복잡한 로직을 거칩니다. 이번 변경에서는 특정 형상(M=1, N=32, K=2880)에 대해 최적화된 MatMulFloatInt4RouterKernel을 도입했습니다.

Before (Generic Path):

// 기존 범용 커널은 공유 메모리 스테이징 및 범용적인 루프 구조를 가짐
__global__ void MatMulFloatInt4Kernel(...) {
  // ... 복잡한 공유 메모리 로드 및 동기화 ...
}

After (Router Specialized Path):

// 라우터 전용 커널은 공유 메모리 없이 글로벌 메모리에서 직접 로드하여 오버헤드 제거
template <typename T, int BlockSize>
__global__ void MatMulFloatInt4RouterKernel(...) {
  // ... Bias를 연산 중간에 즉시 더하여 Add 커널 호출을 생략 ...
  if (threadIdx.x == 0) {
    output[n_id] = result + bias[n_id];
  }
}

2. 그래프 최적화 (matmul_nbits_fusion.cc)

MatMulNBitsAdd 노드가 연속될 때, 이를 하나의 MatMulNBits 노드로 합치는 그래프 변환기를 수정했습니다. 리뷰 과정에서 M=1인 경우에만 융합하도록 제약 조건을 강화하여 런타임 오류를 방지했습니다.

// M=1인 경우에만 Bias Fusion을 허용하도록 제약 조건 추가
if (A_shape.dim(-2) != 1) {
  return false; // M > 1이면 융합하지 않음
}

왜 이게 좋은가

  1. 커널 실행 오버헤드 제거: 라우터 연산 후 발생하는 Add 커널 호출을 제거했습니다. GPT-OSS-20B 모델에서 24개의 Add 노드를 제거하여 약 +0.2%의 처리량 향상을 확인했습니다.
  2. 메모리 접근 최적화: 라우터 특화 커널은 공유 메모리 스테이징을 생략하고 L2 캐시를 효율적으로 활용하도록 설계되어, 전체 라우터 GEMV 성능이 약 +1.6% ~ +1.8% 개선되었습니다.
  3. 안전한 최적화: 리뷰어들의 피드백을 반영하여 M > 1인 경우 융합을 방지하고, int64_t에서 int로의 캐스팅 시 발생할 수 있는 오버플로우 문제를 방지하는 등 견고한 코드를 작성했습니다.

교훈

  • 특화(Specialization)의 힘: 범용 커널은 유지보수에 좋지만, 모델의 특정 병목 지점(Hot path)에서는 고정된 형상에 최적화된 커널이 압도적인 성능을 냅니다.
  • 그래프 융합 시 제약 조건: 그래프 변환기 작성 시, 런타임 커널이 지원하는 제약 조건(예: M=1)을 반드시 사전에 검증해야 런타임 실패를 방지할 수 있습니다.

리뷰어 피드백 반영

  • std::getenv 대신 ORT 전용 GetEnvironmentVar를 사용하여 플랫폼 간 일관성을 확보했습니다.
  • LaunchMatMulNBitsBiasAdd에서 그리드 크기 오버플로우를 방지하기 위해 명시적인 제한 검사를 추가했습니다.
  • IsSupportedRouterGemvShape를 GPT-OSS-20B 전용으로 제한하여 불필요한 테스트 범위를 줄이고 명확성을 높였습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글