본문으로 건너뛰기

[vllm] vLLM, Gemma4 라우팅 함수 Triton 커널로 최적화하여 성능 대폭 향상

PR 링크: vllm-project/vllm#39083 상태: Merged | 변경: +None / -None

들어가며

대규모 언어 모델(LLM) 서빙에서 성능 최적화는 사용자 경험과 직결되는 매우 중요한 요소입니다. 특히 Mixture-of-Experts (MoE) 모델의 경우, 여러 전문가(expert) 중 일부만 선택하여 연산을 수행하는 라우팅(routing) 과정이 전체 성능에 큰 영향을 미칩니다. 최근 vLLM 프로젝트에서는 Gemma4 모델의 라우팅 함수를 최적화하여 서빙 성능을 대폭 향상시키는 PR이 병합되었습니다. 이 글에서는 해당 PR의 코드 변경 사항을 분석하고, 왜 이러한 최적화가 효과적인지, 그리고 어떤 기술적 교훈을 얻을 수 있는지 살펴보겠습니다.

기존 Gemma4 모델의 라우팅 함수는 PyTorch 연산을 기반으로 구현되어 있었습니다. 이 방식은 몇 가지 성능 저하 요인을 가지고 있었습니다. 첫째, PyTorch 연산은 여러 동기화 지점(synchronization points)을 발생시키고 전역 메모리(global memory)에 대한 빈번한 읽기/쓰기를 유발하여 오버헤드가 컸습니다. 둘째, 이러한 PyTorch 연산은 torch.compile과 같은 최신 컴파일러 최적화 기법의 대상에서 제외되어, 최신 컴파일러 기술을 활용한 성능 향상을 온전히 누리지 못했습니다.

본 PR은 이러한 문제를 해결하기 위해 Gemma4 모델에 특화된 라우팅 함수를 NVIDIA Triton을 사용하여 커스텀 커널로 구현했습니다. Triton은 GPU 커널을 더 쉽게 작성하고 최적화할 수 있도록 돕는 도구로, PyTorch 연산 대비 훨씬 효율적인 연산이 가능합니다.

코드 분석

1. pyproject.toml 변경

diff --git a/pyproject.toml b/pyproject.toml
index f55dd9308bd5..5c87de018c10 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -170,6 +170,7 @@ eles = "eles"
 datas = "datas"
 ser = "ser"
 ure = "ure"
+VALU = "VALU"
 # Walsh-Hadamard Transform
 wht = "wht"
 WHT = "WHT"

pyproject.toml 파일에 VALU라는 새로운 상수가 추가되었습니다. 이는 Triton 커널 내부에서 사용될 수 있는 상수 값으로, 코드의 가독성과 유지보수성을 높이는 데 기여합니다. 특별한 성능 향상과는 직접적인 관련은 없지만, 커널 구현의 일부로써 의미를 가집니다.

2. tests/kernels/moe/test_gemma4router.py 신규 추가

diff --git a/tests/kernels/moe/test_gemma4router.py b/tests/kernels/moe/test_gemma4router.py
new file mode 100644
index 000000000000..ba69d6927495
--- /dev/null
+++ b/tests/kernels/moe/test_gemma4router.py
@@ -0,0 +1,57 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+
+from vllm.model_executor.models.gemma4 import (
+    gemma4_fused_routing_kernel_triton,
+    gemma4_routing_function_torch,
+)
+
+
+def sort_by_id(w, ids):
+    order = ids.argsort(dim=-1)
+    return w.gather(1, order), ids.gather(1, order)
+
+
+# Gemma4 MoE Model has context length of 250K
+# the minus 1 is to ensure that edge cases are tested
+@pytest.mark.parametrize("num_tokens", [1, 2, 2048, 250000])
+@pytest.mark.parametrize("num_experts", [128])  # gemma4 moe experts
+@pytest.mark.parametrize("topk", [8])  # gemma4 topk
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
+def test_gemma4_routing_kernel_triton(
+    num_tokens: int,
+    num_experts: int,
+    topk: int,
+    dtype: torch.dtype,
+):
+    torch.manual_seed(0)
+
+    gating = torch.randn(num_tokens, num_experts, dtype=dtype, device="cuda")
+    scales = torch.rand(num_experts, dtype=torch.float32, device="cuda")
+
+    ref_w, ref_ids = gemma4_routing_function_torch(gating, topk, scales)
+    tri_w, tri_ids = gemma4_fused_routing_kernel_triton(gating, topk, scales)
+
+    # Sort by expert id — to remove tie-breaking differences
+    ref_ws, ref_is = sort_by_id(ref_w, ref_ids)
+    tri_ws, tri_is = sort_by_id(tri_w, tri_ids)
+
+    ids_match = (ref_is == tri_is).all().item()
+    weights_match = torch.allclose(ref_ws, tri_ws, atol=1e-2, rtol=1e-2)
+    all_match = ids_match and weights_match
+    max_err = (ref_ws - tri_ws).abs().max().item()
+    print(
+        f"T={num_tokens:5d} E={num_experts:4d} K={topk} "
+        f"{str(dtype).split('.')[-1]:7s} ids={ids_match} max_Δweight={max_err:.2e}"
+    )
+    if not all_match:
+        bad = (ref_is != tri_is).any(dim=-1).nonzero(as_tuple=True)[0]
+        if len(bad):
+            r = bad[0].item()
+            print(
+                f"  first bad row {r}: ref_ids={ref_ids[r].tolist()} "
+                f"tri_ids={tri_ids[r].tolist()}"
+            )
+        assert all_match

이 파일은 새로운 Triton 커널의 정확성을 검증하기 위한 테스트 코드를 포함합니다. gemma4_routing_function_torch (기존 PyTorch 구현)와 gemma4_fused_routing_kernel_triton (새로운 Triton 구현)의 결과를 비교하여 동일한 출력을 내는지 확인합니다. 다양한 num_tokens, num_experts, topk, dtype 조합에 대해 테스트를 수행하며, 결과의 일치 여부와 최대 오차를 출력합니다. 이는 새로운 커널이 기존 로직과 동일하게 동작함을 보장하는 중요한 단계입니다.

3. vllm/model_executor/models/gemma4.py 핵심 변경

3.1. Triton 커널 정의 (_gemma4_routing_kernel, gemma4_fused_routing_kernel_triton)

diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py
index 06189540090d..d166a9df38ac 100644
--- a/vllm/model_executor/models/gemma4.py
+++ b/vllm/model_executor/models/gemma4.py
@@ -57,7 +57,9 @@
     default_weight_loader,
     maybe_remap_kv_scale_name,
 )
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
 
 from .interfaces import (
@@ -79,6 +81,120 @@
 logger = init_logger(__name__)
 
 
+@triton.jit
+def _gemma4_routing_kernel(
+    gating_ptr,
+    per_expert_scale_ptr,
+    topk_weights_ptr,
+    topk_ids_ptr,
+    E: tl.constexpr,
+    K: tl.constexpr,
+    BLOCK_E: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    offs_e = tl.arange(0, BLOCK_E)
+    valid = offs_e < E
+
+    logits = tl.load(
+        gating_ptr + pid * E + offs_e,
+        mask=valid,
+        other=-float("inf"),
+    ).to(tl.float32)
+
+    max_l = tl.max(logits, axis=0)
+
+    # Float32 → ascending-sortable bijection
+    MIN32 = -2147483648
+    logit_bits = logits.to(tl.int32, bitcast=True)
+    sign_b = logit_bits >> 31
+    key = tl.where(sign_b == 0, logit_bits ^ -1, logit_bits ^ MIN32)
+    key = tl.where(valid, key, 0x7FFFFFFF)
+    sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF
+    packed = (sk64 << 32) | offs_e.to(tl.int64)
+    sorted_p = tl.sort(packed, descending=False)
+
+    # Vectorized extraction of ALL sorted elements — no K-loop, no cross-lane reductions
+    all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32)
+    all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32)
+
+    # Inverse bijection: recover original logit bits
+    sign_k = all_keys >> 31
+    all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32)
+    all_logits = all_bits.to(tl.float32, bitcast=True)
+
+    # Compute raw_exp for ALL BLOCK_E elements — vectorized, ~2 VALU clocks
+    all_raw_exp = tl.math.exp2((all_logits - max_l) * 1.4426950408889634)
+
+    # Sum only top-K for renorm — ONE masked reduction
+    top_mask = offs_e < K
+    renorm_raw = tl.sum(tl.where(top_mask, all_raw_exp, 0.0), axis=0)
+    renorm_raw = tl.where(renorm_raw > 0.0, renorm_raw, 1.0)
+    inv_renorm = 1.0 / renorm_raw
+
+    # Load scales for top-K only (masked gather; scale array is tiny → L1 cached)
+    all_scales = tl.load(
+        per_expert_scale_ptr + all_ids.to(tl.int64),
+        mask=top_mask,
+        other=1.0,
+    ).to(tl.float32)
+
+    # Final weights: vectorized multiply (only top-K will be stored)
+    all_weights = (all_raw_exp * inv_renorm * all_scales).to(tl.float32)
+
+    # Write results with TWO masked stores — replaces K × 2 serial scalar stores
+    base_off = pid * K + offs_e
+    tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask)
+    tl.store(topk_weights_ptr + base_off, all_weights, mask=top_mask)
+    
+
+def gemma4_fused_routing_kernel_triton(
+    gating_output: torch.Tensor,
+    topk: int,
+    per_expert_scale: torch.Tensor,
+    num_warps: int = 1,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    gating_output = gating_output.contiguous()
+    per_expert_scale = per_expert_scale.contiguous()
+    T, E = gating_output.shape
+    weights = torch.empty(T, topk, dtype=torch.float32, device=gating_output.device)
+    ids = torch.empty(T, topk, dtype=torch.int32, device=gating_output.device)
+    BLOCK_E = triton.next_power_of_2(E)
+    _gemma4_routing_kernel[(T,)](
+        gating_output,
+        per_expert_scale,
+        weights,
+        ids,
+        E,
+        topk,
+        BLOCK_E,
+        num_warps=num_warps,
+    )
+    return weights, ids
+
+
def gemma4_routing_function_torch(
+    gating_output: torch.Tensor,
+    topk: int,
+    per_expert_scale: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    _, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
+    router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1)
+    indicator = torch.nn.functional.one_hot(
+        topk_ids, num_classes=gating_output.size(-1)
+    ).sum(dim=-2)
+    gate_weights = indicator * router_probabilities
+    renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True)
+    renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0)
+    dispatch_weights = gate_weights / renorm_factor
+
+    topk_weights = dispatch_weights.gather(1, topk_ids)
+
+    # Fold per_expert_scale into routing weights
+    expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype)
+    topk_weights = topk_weights * expert_scales
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
 
 def _get_text_config(config):
     """Dereference text_config if config is a nested Gemma4Config.
@@ -216,22 +332,12 @@ def routing_function(
             topk: int,
             renormalize: bool,
         ) -> tuple[torch.Tensor, torch.Tensor]:
-            _, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
-            router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1)
-            indicator = torch.nn.functional.one_hot(
-                topk_ids, num_classes=gating_output.size(-1)
-            ).sum(dim=-2)
-            gate_weights = indicator * router_probabilities
-            renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True)
-            renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0)
-            dispatch_weights = gate_weights / renorm_factor
-
-            topk_weights = dispatch_weights.gather(1, topk_ids)
-
-            # Fold per_expert_scale into routing weights
-            expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype)
-            topk_weights = topk_weights * expert_scales
-            return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+            if current_platform.is_cuda_alike() or current_platform.is_xpu():
+                return gemma4_fused_routing_kernel_triton(
+                    gating_output, topk, per_expert_scale
+                )
+
+            return gemma4_routing_function_torch(gating_output, topk, per_expert_scale)
 
         # FusedMoE experts with custom Gemma4 routing
         self.experts = FusedMoE( 
  • _gemma4_routing_kernel: Triton 언어를 사용하여 구현된 실제 커널 함수입니다. 이 함수는 입력 gating_output에서 상위 K개의 전문가를 선택하고, 각 전문가에 대한 가중치와 인덱스를 계산합니다. Triton의 강력한 병렬 처리 기능을 활용하여 메모리 접근을 최소화하고 연산 속도를 극대화하도록 설계되었습니다. 특히, Float32 → ascending-sortable bijection 부분은 부동 소수점 값을 정렬 가능한 정수 형태로 변환하여 효율적인 정렬을 수행하는 기법입니다. 또한, Vectorized extraction of ALL sorted elementsONE masked reduction 등은 병렬 처리의 이점을 극대화합니다.
  • gemma4_fused_routing_kernel_triton: 위 Triton 커널을 호출하는 파이썬 래퍼 함수입니다. 입력 텐서를 정리하고, Triton 커널 실행에 필요한 파라미터를 설정하며, 결과를 반환합니다. BLOCK_Enum_warps 같은 파라미터는 Triton 커널의 성능 튜닝에 사용됩니다.
  • routing_function 수정: 기존의 routing_function은 이제 current_platform.is_cuda_alike() or current_platform.is_xpu() 조건에 따라 Triton 커널 (gemma4_fused_routing_kernel_triton) 또는 기존 PyTorch 구현 (gemma4_routing_function_torch)을 선택적으로 호출하도록 변경되었습니다. 이는 CUDA 또는 XPU 환경에서는 최적화된 Triton 커널을 사용하고, 그렇지 않은 환경에서는 기존의 안정적인 PyTorch 구현을 사용하도록 하여 호환성을 유지합니다.

기존 PyTorch 구현(gemma4_routing_function_torch)은 다음과 같습니다:

def gemma4_routing_function_torch(
    gating_output: torch.Tensor,
    topk: int,
    per_expert_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    _, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
    router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1)
    indicator = torch.nn.functional.one_hot(
        topk_ids, num_classes=gating_output.size(-1)
    ).sum(dim=-2)
    gate_weights = indicator * router_probabilities
    renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True)
    renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0)
    dispatch_weights = gate_weights / renorm_factor

    topk_weights = dispatch_weights.gather(1, topk_ids)

    # Fold per_expert_scale into routing weights
    expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype)
    topk_weights = topk_weights * expert_scales
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

이 코드는 torch.topk, torch.nn.functional.softmax, torch.nn.functional.one_hot 등 여러 PyTorch 연산을 순차적으로 호출합니다. 각 연산은 GPU 커널 간의 동기화 및 메모리 전송을 유발할 수 있어 성능 병목의 원인이 됩니다.

4. 리뷰 댓글 반영

리뷰 과정에서 torch.compile의 효과에 대한 논의가 있었습니다. [ProExpertProg]님은 Triton 커널 대신 torch.compile을 적용하는 방안을 제안했지만, [tjtanaa]님은 마이크로벤치마크 결과 Triton 커널이 torch.compile보다 훨씬 빠르다는 것을 보여주었습니다.

마이크로벤치마크 결과 (A100 기준, vs Eager 대비):

GPU vs Eager vs compile-default vs compile-reduce-overhead vs compile-max-autotune
A100 12.68× 4.50× 4.97× 4.45×

이 결과는 torch.compile이 기존 PyTorch 코드 대비 성능을 향상시키기는 하지만, Triton으로 직접 구현한 커널의 성능에는 미치지 못함을 명확히 보여줍니다. 이는 복잡하고 성능에 민감한 연산의 경우, 직접 커널을 작성하는 것이 여전히 가장 효과적인 최적화 방법일 수 있음을 시사합니다.

또한, [tjtanaa]님은 Triton 커널의 num_warps를 1로 설정하는 것이 성능에 가장 좋다는 경험적 사실을 공유했습니다. 이는 Triton 커널 튜닝의 중요성을 보여주는 예시입니다.

왜 이게 좋은가?

성능 향상

이 PR의 가장 큰 장점은 Gemma4 모델의 서빙 성능을 크게 향상시켰다는 점입니다. 제공된 벤치마크 결과는 이를 명확하게 입증합니다.

전체 처리량 (Total Throughput) 향상:

GPU Total Throughput (tok/s)
A100 +8.6% (3,419 → 3,689)
MI300X +21.4% (5,252 → 6,382)
H100 +4.3% (5,813 → 6,060)
B60 +32.0% (477 → 629)

특히 MI300X와 B60 GPU에서 20% 이상의 상당한 성능 향상을 보였습니다. 이는 라우팅 함수의 병목 현상이 해소되었음을 의미합니다.

지연 시간 감소:

GPU TPOT (ms)
A100 -8.5% (20.98 → 19.25)
MI300X -18.9% (14.07 → 11.41)
H100 -4.4% (12.73 → 12.17)
B60 -26.0% (157.1 → 116.2)

TPOT (Time Per Output Token) 지연 시간 또한 모든 GPU에서 감소하여, 응답 속도가 빨라졌음을 알 수 있습니다.

기술적 교훈

  1. 커스텀 커널의 중요성: 복잡하고 성능에 민감한 연산의 경우, PyTorch의 기본 연산이나 torch.compile만으로는 충분한 성능을 얻기 어려울 수 있습니다. Triton과 같은 도구를 사용하여 직접 GPU 커널을 작성하는 것이 때로는 최고의 성능 향상 방법입니다. 특히 MoE 모델의 라우팅과 같이 모델 아키텍처에 특화된 연산에서 효과적입니다.
  2. 메모리 접근 최소화 및 병렬 처리 극대화: Triton 커널은 전역 메모리 접근을 최소화하고, 데이터를 레지스터나 공유 메모리에 최대한 활용하며, 병렬 처리 능력을 극대화하도록 설계되었습니다. 이는 기존 PyTorch 구현의 동기화 오버헤드와 메모리 접근 문제를 해결하는 핵심입니다.
  3. 플랫폼별 최적화: current_platform을 사용하여 CUDA 및 XPU 환경에서 최적화된 Triton 커널을 적용하고, 그렇지 않은 환경에서는 기존 PyTorch 구현을 사용하는 방식은 코드의 재사용성과 호환성을 높이는 좋은 전략입니다.
  4. 정확성 검증의 중요성: 새로운 커널 구현 시, 기존 구현과의 정확성 차이를 엄격하게 테스트하는 것이 필수적입니다. test_gemma4router.py 파일의 테스트 케이스들은 이를 잘 보여줍니다. 특히 atol=1e-2, rtol=1e-2와 같은 허용 오차 설정은 부동 소수점 연산의 특성을 고려한 것입니다.

결론

vLLM의 Gemma4 라우팅 함수 Triton 커널 도입은 LLM 서빙 성능 최적화의 좋은 사례를 보여줍니다. 커스텀 커널 개발을 통해 기존의 PyTorch 기반 구현이 가진 성능 병목을 효과적으로 해결했으며, 이는 다양한 GPU 환경에서 실질적인 성능 향상으로 이어졌습니다. 앞으로도 vLLM은 이러한 적극적인 최적화를 통해 LLM 서빙의 한계를 넓혀나갈 것으로 기대됩니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글