본문으로 건너뛰기

[vllm] vLLM, DeepSeek V4 모델의 저지연을 위한 RMSNorm과 라우터 GEMV 연산 융합으로 성능 극대화

PR 링크: vllm-project/vllm#41263 상태: Merged | 변경: +0 / -0

들어가며

최근 대규모 언어 모델(LLM) 추론 시스템에서 지연 시간(latency)은 사용자 경험과 직결되는 매우 중요한 성능 지표입니다. 특히 실시간 상호작용이 중요한 애플리케이션에서는 낮은 지연 시간을 달성하는 것이 필수적입니다. vLLM은 LLM 추론을 위한 고성능 라이브러리로, 지속적으로 최적화 기법을 도입하여 LLM 추론의 한계를 넓혀가고 있습니다. 이번 PR (#2506)은 DeepSeek V4 모델에서 발생하는 특정 병목 구간을 개선하여 저지연 시나리오에서의 성능을 크게 향상시키는 것을 목표로 합니다.

이 PR의 핵심은 DeepSeek V4 모델의 Mixture-of-Experts (MoE) 레이어에서 발생하는 두 가지 연산, 즉 RMSNorm과 라우터(router) GEMV(General Matrix-Vector Multiplication) 연산을 하나의 CUDA 커널로 융합(fuse)하는 것입니다. 기존에는 이 두 연산이 별도의 커널로 실행되어 여러 커널 호출 오버헤드와 중간 결과 저장을 위한 메모리 접근이 발생했습니다. 이를 단일 커널로 통합함으로써, 커널 실행 시간을 단축하고 메모리 대역폭 사용량을 줄여 전반적인 추론 지연 시간을 개선하고자 합니다.

본 글에서는 이 PR의 코드 변경 사항을 상세히 분석하고, 왜 이러한 융합이 성능 향상으로 이어지는지, 그리고 실제 성능 개선 수치는 어떠한지 살펴보겠습니다.

코드 분석

이번 PR은 주로 CUDA 커널 구현과 관련 파이썬/C++ 바인딩, 그리고 벤치마크 코드를 수정했습니다.

1. CMakeLists.txt: 새로운 CUDA 커널 빌드 설정

먼저, 새로운 융합 커널을 빌드하기 위한 설정이 CMakeLists.txt에 추가되었습니다.

--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1222,6 +1222,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       CUDA_ARCHS "${DSV3_ROUTER_GEMM_ARCHS}")
     list(APPEND VLLM_MOE_EXT_SRC "${DSV3_ROUTER_GEMM_SRC}")
     message(STATUS "Building DSV3 router GEMM kernel for archs: ${DSV3_ROUTER_GEMM_ARCHS}")
+
+    # DeepSeek V4 fused RMSNorm + router GEMV - same arch gating as DSV3.
+    set(DSV4_NORM_ROUTER_GEMM_SRC
+      "csrc/moe/dsv4_norm_router_gemm_entry.cu"
+      "csrc/moe/dsv4_norm_router_gemm_kernel.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${DSV4_NORM_ROUTER_GEMM_SRC}"
+      CUDA_ARCHS "${DSV3_ROUTER_GEMM_ARCHS}")
+    list(APPEND VLLM_MOE_EXT_SRC "${DSV4_NORM_ROUTER_GEMM_SRC}")
+    message(STATUS "Building DSV4 norm+router GEMV kernel for archs: ${DSV3_ROUTER_GEMM_ARCHS}")
   else()
     message(STATUS "Not building DSV3 router GEMM kernel as no compatible archs found"
                    " (requires SM90+ and CUDA >= 12.0)")

DSV4_NORM_ROUTER_GEMM_SRC 변수에 새로운 커널 파일(dsv4_norm_router_gemm_entry.cu, dsv4_norm_router_gemm_kernel.cu)이 추가되었습니다. 이는 기존 DSV3 라우터 GEMM 커널과 동일한 아키텍처(SM90+ 및 CUDA 12.0 이상)를 타겟으로 하며, 빌드 시스템에 새로운 융합 커널을 포함하도록 설정합니다. 이로써 컴파일 시점에 이 커널들이 vLLM 라이브러리에 포함될 수 있게 됩니다.

2. benchmarks/kernels/benchmark_norm_router_gemm.py: 벤치마크 및 정확도 검증

이 파일은 새로운 융합 커널의 성능과 정확성을 검증하기 위해 새로 추가되었습니다.

# ... (imports and constants) ...

def unfused_norm_router_gemm(
    x: torch.Tensor,
    norm_weight: torch.Tensor,
    gate_weight: torch.Tensor,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Call ``_C::rms_norm`` directly (mirroring ``_dsv4_pro_norm_gate``'s
    # fallback path) so the benchmarked baseline doesn't inherit any
    # Python wrapper overhead or risk falling through to the native
    # eager-primitive ``RMSNorm.forward_native`` path.
    normed = torch.empty_like(x)
    torch.ops._C.rms_norm(normed, x, norm_weight, eps)
    logits = vllm_ops.dsv3_router_gemm(normed, gate_weight, torch.float32)
    return normed, logits

def fused_norm_router_gemm(
    x: torch.Tensor,
    norm_weight: torch.Tensor,
    gate_weight: torch.Tensor,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    return vllm_ops.dsv4_norm_router_gemm(x, norm_weight, gate_weight, eps)

# ... (input generation and calculate_diff function) ...

def get_benchmark():
    # ... (benchmark setup) ...
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens"],
            x_vals=list(range(1, 17)),
            line_arg="provider",
            line_vals=["unfused", "fused"],
            line_names=["unfused (rms+dsv3)", "fused (dsv4)"],
            styles=[("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"norm-router-gemm-E{num_experts}-H{HIDDEN_SIZE}",
            args={},
        )
    )
    def benchmark(num_tokens, provider):
        # ... (benchmarking logic) ...
        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark

# ... (main function) ...
  • unfused_norm_router_gemm: 이 함수는 기존 방식을 시뮬레이션합니다. 먼저 torch.ops._C.rms_norm을 호출하여 RMSNorm을 계산하고, 그 결과를 vllm_ops.dsv3_router_gemm에 전달하여 라우터 로짓을 계산합니다. 이는 두 번의 커널 호출과 중간 결과(normed) 저장을 포함합니다.
  • fused_norm_router_gemm: 이 함수는 새로 구현된 vllm_ops.dsv4_norm_router_gemm을 호출합니다. 이 단일 커널은 입력 텐서 x로부터 RMSNorm 계산과 라우터 로짓 계산을 동시에 수행합니다.
  • calculate_diff: 이 함수는 융합된 커널과 분리된 커널의 출력 결과가 동일한지 정확도를 검증합니다. torch.allclose를 사용하여 두 결과 간의 최대 절대 오차(max absolute difference)를 확인하며, 이는 약 1 ULP(Unit in the Last Place) 수준으로 일치해야 함을 명시합니다. 이는 융합이 연산의 정확성을 해치지 않음을 보장합니다.
  • get_benchmark: Triton 라이브러리를 사용하여 융합된 커널과 분리된 커널의 성능을 측정합니다. num_tokens를 변화시키면서 각 연산의 지연 시간을 측정하고 그래프로 시각화합니다. 이 벤치마크는 DeepSeek V4 Pro의 특정 하이퍼파라미터(HIDDEN_SIZE = 7168, NUM_EXPERTS_CHOICES = (384,))에 맞춰져 있습니다.

3. csrc/moe/dsv4_norm_router_gemm.h: 융합 커널 헤더

이 헤더 파일은 융합된 CUDA 커널의 인터페이스를 정의합니다.

/*
 * Fused RMSNorm + router GEMV for DeepSeek V4.
 * ... (comments explaining the logic) ...
 */

#pragma once

#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include "dsv3_router_gemm_utils.h"

template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeNormRouterGemm(float* logits, __nv_bfloat16* normed_x, T const* x,
                          T const* norm_weight, T const* gate_weight, float eps,
                          cudaStream_t stream);

핵심은 invokeNormRouterGemm 템플릿 함수입니다. 이 함수는 입력 데이터 타입 T, 토큰 수 kNumTokens, 전문가 수 kNumExperts, 히든 차원 kHiddenDim을 템플릿 인자로 받습니다. 이 템플릿 메커니즘을 통해 컴파일 시점에 특정 하드웨어 및 모델 구성에 최적화된 커널 코드를 생성할 수 있습니다. 특히, DeepSeek V4 Pro의 경우 kDsv4NumExperts = 384kDsv4HiddenDim = 7168로 고정되어 있습니다. 이 커널은 RMSNorm 계산 결과를 전역 메모리(normed_x)에 쓰고, 동시에 라우터 로짓(logits)을 계산합니다. 로짓은 float 타입으로 출력되며, 이는 DeepSeek V4 모델의 요구사항과 일치합니다.

4. csrc/moe/dsv4_norm_router_gemm_entry.cu: PyTorch Op Wrapper

이 파일은 위에서 정의된 융합 커널을 PyTorch 텐서와 연동하기 위한 C++ Op wrapper를 구현합니다.

/*
 * TORCH op entry for the fused RMSNorm + router GEMV kernel
 * (DeepSeek V4 Pro). ...
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include "core/registration.h"
#include "dsv4_norm_router_gemm.h"

namespace {

// DSV4-Pro hard-coded shape constants.
constexpr int kDsv4NumExperts = 384;
constexpr int kDsv4HiddenDim = 7168;

// ... (LoopUnroller template for dispatching invokeNormRouterGemm) ...

}  // namespace

void dsv4_norm_router_gemm(at::Tensor& logits,    // [num_tokens, E] fp32
                           at::Tensor& normed_x,  // [num_tokens, H] bf16
                           at::Tensor const& x,   // [num_tokens, H] bf16
                           at::Tensor const& norm_weight,  // [H] bf16
                           at::Tensor const& gate_weight,  // [E, H] bf16
                           float eps) {
    const int num_tokens = x.size(0);
    // ... (input validation and stream retrieval) ...

    // Dispatch to the templated kernel based on num_tokens.
    // Note: kDsv4NumExperts and kDsv4HiddenDim are fixed at compile time.
    LoopUnroller<1, 17>::unroll(num_tokens, 
                               logits.data_ptr<float>(),
                               normed_x.data_ptr<__nv_bfloat16>(),
                               x.data_ptr<__nv_bfloat16>(),
                               norm_weight.data_ptr<__nv_bfloat16>(),
                               gate_weight.data_ptr<__nv_bfloat16>(),
                               eps,
                               stream);
}

// Register the PyTorch op
TORCH_LIBRARY(vllm_ops, m) {
    m.def(
        "dsv4_norm_router_gemm(
            Tensor logits,
            Tensor normed_x,
            Tensor x,
            Tensor norm_weight,
            Tensor gate_weight,
            float eps
        ) -> void",
        &dsv4_norm_router_gemm,
        py::call_guard<py::gil_scoped_release>());
}

dsv4_norm_router_gemm 함수는 PyTorch 텐서를 인자로 받아, CUDA 스트림을 얻고, num_tokens 값에 따라 LoopUnroller를 통해 적절한 invokeNormRouterGemm 템플릿 인스턴스를 호출합니다. LoopUnrollernum_tokens가 1부터 16까지의 범위에 있을 때 컴파일 타임에 최적화된 커널을 선택하도록 합니다. 이 함수는 PyTorch의 TORCH_LIBRARY 매크로를 통해 vllm_ops.dsv4_norm_router_gemm이라는 이름으로 등록되어 파이썬 코드에서 직접 호출될 수 있게 됩니다.

5. csrc/moe/dsv4_norm_router_gemm_kernel.cu: 실제 융합 커널 구현

이 파일에는 실제 융합 연산이 수행되는 CUDA 커널 코드가 포함되어 있습니다. 이 커널은 SM90+ 아키텍처(Ampere 이상)를 타겟으로 하며, Triton의 pdl (Parallel Distribution Language)과 유사한 기법을 사용하여 효율성을 극대화합니다.

/*
 * Fused RMSNorm + router GEMV for DeepSeek V4.
 * ... (detailed explanation of the algorithm) ...
 */

#include "dsv4_norm_router_gemm.h"

// ... (utility functions for warp reduction, etc.) ...

template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeNormRouterGemm(float* logits, __nv_bfloat16* normed_x, T const* x,
                          T const* norm_weight, T const* gate_weight, float eps,
                          cudaStream_t stream) {
    // ... (kernel launch configuration: grid, block dimensions) ...

    // Kernel launch configuration calculation
    // Grid dimensions: number of blocks needed to cover all tokens and experts
    // Block dimensions: threads per block, warp size, etc.
    dim3 grid_dim(...);
    dim3 block_dim(...);

    // Launch the kernel
    // The kernel itself will perform the fused RMSNorm and GEMV operations.
    // It utilizes shared memory and warp-level operations for efficiency.
    // Example kernel invocation (actual implementation is more complex):
    // fused_norm_router_gemm_kernel<T, kNumTokens, kNumExperts, kHiddenDim>(
    //    logits, normed_x, x, norm_weight, gate_weight, eps, ...
    // );
    // cudaError_t err = cudaGetLastError();
    // if (err != cudaSuccess) { /* handle error */ }
    // cudaStreamSynchronize(stream);
}

// Explicit instantiations for common configurations (e.g., DSV4-Pro)
// template void invokeNormRouterGemm<__nv_bfloat16, 16, 384, 7168>(...);
// ... other instantiations ...

이 커널 구현의 핵심은 다음과 같습니다:

  1. 알고리즘: RMSNorm 계산(x[m,k] * rsqrt(mean(x[m]^2) + eps) * norm_weight[k])과 라우터 GEMV 계산(sum_k(normed_x[m,k] * gate_weight[n,k]))을 하나의 커널 내에서 수행합니다. 이를 위해 normed_x가 전역 메모리에 기록되기 전에 GEMV 계산을 수행할 수 있도록 대수적 항등식(logits[m,n] = rsqrt[m] * sum_k(x[m,k] * nw[k] * gw[n,k]))을 활용합니다. 이는 normed_x가 메모리에 쓰여지고 다시 읽혀오는 과정을 생략하여 지연 시간을 줄입니다.
  2. 데이터 타입: 입력 x, norm_weight, gate_weight__nv_bfloat16 (BF16) 타입을 사용하며, 중간 계산 및 최종 logits 출력은 float 타입을 사용합니다. normed_x 또한 BF16으로 출력됩니다.
  3. 최적화 기법: SM90+ 아키텍처의 기능을 활용하여 워프(warp) 단위 연산, 공유 메모리(shared memory) 활용, 그리고 효율적인 병렬 처리(PDL - Parallel Distribution Language)를 통해 성능을 극대화합니다. 특히, GEMV 부분은 DSV3 라우터 GEMM 커널과 유사한 최적화 기법을 따릅니다.
  4. 템플릿 특수화: invokeNormRouterGemm 함수는 다양한 kNumTokens, kNumExperts, kHiddenDim 조합에 대해 템플릿으로 정의되며, 컴파일 시점에 실제 사용될 설정값(예: DSV4 Pro의 16개 토큰, 384개 전문가, 7168 히든 차원)에 맞게 특수화되어 최적의 코드를 생성합니다.

왜 이게 좋은가?

이 PR은 다음과 같은 이유로 좋은 최적화/개선이라고 할 수 있습니다.

  1. 지연 시간 감소 (Latency Reduction):

    • 커널 융합: 두 개의 독립적인 CUDA 커널 호출(RMSNorm, GEMV)을 하나의 커널로 통합함으로써, 커널 실행 자체에 소요되는 오버헤드를 줄입니다. 각 커널 호출에는 GPU 커널 디스패치 및 동기화 비용이 수반되는데, 이를 제거하면 상당한 지연 시간 단축 효과를 얻을 수 있습니다.
    • 메모리 대역폭 절약: 기존 방식에서는 RMSNorm 연산의 중간 결과인 normed_x가 전역 메모리에 쓰여지고, 이후 GEMV 연산에서 다시 읽혀와야 했습니다. 융합 커널은 이 중간 결과를 전역 메모리에 쓰지 않고 공유 메모리나 레지스터 내에서 직접 GEMV 연산에 활용함으로써, 메모리 읽기/쓰기 횟수를 줄여 메모리 대역폭 병목 현상을 완화합니다. 이는 특히 메모리 대역폭이 제한적인 환경에서 큰 성능 향상을 가져옵니다.
  2. 처리량 향상 (Throughput Improvement):

    • 성능 테스트 결과에 따르면, 여러 시나리오에서 초당 처리량(Req/s, Tok/s)이 향상되었습니다. 예를 들어, Conc=1일 때 Total tok/s가 610.22에서 622.87로, Conc=4일 때 2031.17에서 2075.87로 증가했습니다. 이는 지연 시간 감소와 메모리 효율성 증대가 결합된 결과입니다.
  3. 정확성 보장: benchmarks/kernels/benchmark_norm_router_gemm.py 파일의 calculate_diff 함수를 통해 새로운 융합 커널의 결과가 기존 분리된 커널의 결과와 거의 동일함(약 1 ULP 이내)을 검증했습니다. 이는 성능 향상이 모델의 정확성을 해치지 않음을 보장합니다.

  4. 특정 모델 최적화: DeepSeek V4와 같이 특정 구조(MoE, 특정 차원)를 가진 모델에 대한 깊이 있는 분석을 통해, 해당 모델의 병목 지점을 정확히 파악하고 최적화된 커널을 개발했습니다. 이는 범용적인 최적화보다 훨씬 큰 성능 향상을 가져올 수 있습니다.

성능 수치 요약 (PR 제공 데이터 기준):

  • 처리량 (Tok/s): 대부분의 Concatenation(Conc) 값에서 향상되었습니다. 예를 들어 Conc=1에서 약 2% 증가, Conc=4에서 약 2.2% 증가했습니다.
  • 지연 시간 (P50 TTFT - Time To First Token): 일부 시나리오에서 약간의 개선이 있었습니다. 예를 들어 Conc=16에서 1238.56ms에서 1179.68ms로 약 4.7% 감소했습니다. Conc=4에서는 1058.05ms에서 1052.65ms로 소폭 감소했습니다.
  • P50 TPOT (Time Per Output Token): 대부분의 시나리오에서 약간의 감소를 보였습니다. 예를 들어 Conc=1에서 12.27ms에서 12.01ms로 약 2.1% 감소했습니다.

이러한 수치들은 특히 저지연 시나리오에서 융합 커널의 이점을 명확히 보여줍니다.

일반적인 교훈: LLM 추론 성능을 최적화할 때, 연산 그래프를 분석하여 커널 융합이 가능한 지점을 찾는 것이 중요합니다. 특히, 여러 연산이 순차적으로 실행되고 중간 결과 저장을 위해 메모리 접근이 빈번한 경우, 이를 단일 커널로 통합하면 지연 시간과 메모리 대역폭 사용량을 크게 개선할 수 있습니다. 또한, 특정 모델 아키텍처의 특성을 활용한 맞춤형 커널 개발은 범용 최적화보다 더 큰 성능 향상을 가져올 수 있습니다.

References

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글