본문으로 건너뛰기

[vllm] vLLM ROCm 환경에서 AITER를 활용한 Multi-Head Convolutions(MHC) 성능 최적화 및 안정성 개선

PR 링크: vllm-project/vllm#41946 상태: Merged | 변경: +0 / -0

들어가며

최근 대규모 언어 모델(LLM)의 발전 속도는 눈부시지만, 이를 효율적으로 서빙하기 위한 최적화 노력 또한 그에 못지않게 중요합니다. 특히 NVIDIA GPU가 아닌 AMD GPU 환경에서의 LLM 서빙 성능은 중요한 과제 중 하나입니다. vLLM은 뛰어난 성능으로 LLM 추론 엔진의 표준으로 자리 잡고 있으며, 이번 PR은 vLLM이 AMD GPU(ROCm 환경)에서 Multi-Head Convolutions(MHC) 연산의 성능을 최적화하고, 특정 라이브러리 의존성으로 인한 안정성 문제를 해결하는 데 중점을 두고 있습니다.

이 PR은 크게 두 가지 목표를 달성합니다:

  1. ROCm 환경에서의 AITER MHC 커널 통합 및 성능 최적화: 최신 AITER 라이브러리를 활용하여 MHC 연산(mhc-pre, mhc-post, hc-head)의 성능을 개선합니다.
  2. Tilelang 의존성 제거 및 안정성 향상: Tilelang 라이브러리가 설치되지 않은 환경에서도 vLLM이 정상적으로 동작하도록 수정하여, ROCm 환경에서의 호환성 및 안정성을 높입니다.

이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 이러한 변경이 왜 성능 향상과 안정성 개선으로 이어지는지, 그리고 기술적인 교훈은 무엇인지 살펴보겠습니다.

코드 분석

1. requirements/rocm.txt 변경

Before:

--- a/requirements/rocm.txt
+++ b/requirements/rocm.txt
@@ -21,6 +21,3 @@ timm>=1.0.17
 # amd-quark: required for Quark quantization on ROCm 
 # To be consistent with test_quark.py
 amd-quark>=0.8.99
-# tilelang has to be installed for mhc module to be
-# imported correctly.
-tilelang==0.1.9

After: (위 diff는 제거된 라인을 보여줍니다.)

requirements/rocm.txt 파일에서 tilelang==0.1.9 라인이 제거되었습니다. 이는 더 이상 ROCm 환경에서 vLLM을 빌드하거나 실행할 때 tilelang 라이브러리가 필수적으로 요구되지 않음을 의미합니다. 이전에는 MHC 모듈이 올바르게 임포트되기 위해 tilelang 설치가 필요했지만, 이 PR을 통해 그 의존성이 제거되었습니다.

2. tests/kernels/test_mhc_kernels.py 변경

Before:

--- a/tests/kernels/test_mhc_kernels.py
+++ b/tests/kernels/test_mhc_kernels.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-import vllm.model_executor.layers.mhc as mhc_ops  # noqa: F401
+import vllm.model_executor.kernels.mhc  # noqa: F401
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import set_random_seed
 

After:

@@ -121,7 +121,7 @@ def run_ref():
 
     residual_ref, post_mix_ref, res_mix_ref, layer_input_ref = run_ref()
 
-    residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre(
+    residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre_tilelang(
         x,
         residual,
         post_layer_mix,
@@ -140,3 +140,42 @@ def run_ref():
     torch.testing.assert_close(post_mix, post_mix_ref, atol=1e-2, rtol=1e-2)
     torch.testing.assert_close(res_mix, res_mix_ref, atol=1e-2, rtol=1e-2)
     torch.testing.assert_close(x, layer_input_ref, atol=1e-2, rtol=1e-2)
+
+@pytest.mark.skipif(
+    not current_platform.is_rocm(),
+    reason="ROCm required",
+)
+@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
+@pytest.mark.parametrize("hidden_size", [4096, 7168])
+@pytest.mark.parametrize("hc_mult", [4])
+def test_hc_head_triton(num_tokens, hidden_size, hc_mult):
+    torch.set_default_device(DEVICE)
+    set_random_seed(0)
+
+    residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
+    fn = torch.randn((hc_mult, hc_mult * hidden_size), dtype=torch.float32) * 1e-4
+    hc_scale = torch.randn((1,), dtype=torch.float32) * 0.1
+    hc_base = torch.randn((hc_mult,), dtype=torch.float32) * 0.1
+    rms_eps = hc_eps = 1e-6
+
+    out = torch.empty((num_tokens, hidden_size), dtype=torch.bfloat16)
+    out.fill_(float("nan"))
+
+    result = torch.ops.vllm.hc_head_triton(
+        residual,
+        fn,
+        hc_scale,
+        hc_base,
+        out,
+        hidden_size,
+        rms_eps,
+        hc_eps,
+        hc_mult,
+    )
+
+    assert result is None
+    assert not torch.isnan(out).any()
+
+    out_ref = hc_head_ref(residual, fn, hc_scale, hc_base, rms_eps, hc_eps)
+    torch.testing.assert_close(out, out_ref, atol=5e-2, rtol=1e-2)
  • import vllm.model_executor.layers.mhc as mhc_opsimport vllm.model_executor.kernels.mhc로 변경되었습니다. 이는 MHC 관련 연산들이 더 이상 layers 모듈에 속하지 않고, kernels 모듈로 재구성되었음을 나타냅니다. 이는 코드의 구조적인 개선 및 관심사의 분리를 보여줍니다.
  • torch.ops.vllm.mhc_fused_post_pre 호출이 torch.ops.vllm.mhc_fused_post_pre_tilelang으로 변경되었습니다. 이는 해당 연산이 이제 Tilelang 구현을 사용하도록 명시한 것입니다.
  • 새로운 test_hc_head_triton 함수가 추가되었습니다. 이 테스트는 ROCm 환경에서 hc_head 연산이 Triton 커널을 통해 올바르게 동작하는지 검증합니다. 이는 hc_head 연산의 ROCm 지원 및 성능 개선을 위한 핵심적인 부분입니다.

3. vllm/_aiter_ops.py 변경

Before: (이 파일의 해당 부분은 PR 설명에 직접적으로 나타나지 않지만, 기존에는 AITER 관련 로직이 없었거나 다른 방식으로 구현되었을 것으로 추정됩니다.)

After:

@@ -2395,5 +2395,183 @@ def paged_attention_common(
             kv_cache_dtype=kv_cache_dtype,
         )
 
+    @staticmethod
+    def mhc_pre(
+        residual: torch.Tensor,
+        fn: torch.Tensor,
+        hc_scale: torch.Tensor,
+        hc_base: torch.Tensor,
+        rms_eps: float,
+        hc_pre_eps: float,
+        hc_sinkhorn_eps: float,
+        hc_post_mult_value: float,
+        sinkhorn_repeat: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Forward pass for mHC pre block.
+
+        Args:
+            residual: shape (..., hc_mult, hidden_size), dtype torch.bfloat16
+            fn: shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32
+            hc_scale: shape (3,), dtype torch.float32
+            hc_base: shape (hc_mult3,), dtype torch.float32
+            rms_eps: RMS normalization epsilon
+            hc_pre_eps: pre-mix epsilon
+            hc_sinkhorn_eps: sinkhorn epsilon
+            hc_post_mult_value: post-mix multiplier value
+            sinkhorn_repeat: number of sinkhorn iterations
+            n_splits: split-k factor;
+
+        Returns:
+            post_mix: shape (..., hc_mult), dtype torch.float32
+            comb_mix: shape (..., hc_mult, hc_mult), dtype torch.float32
+            layer_input: shape (..., hidden_size), dtype torch.bfloat16
+        """
+        from aiter.ops.mhc import mhc_pre
+
+        # Validate shapes
+        assert residual.dtype == torch.bfloat16
+        assert fn.dtype == torch.float32
+        assert hc_scale.dtype == torch.float32
+        assert hc_base.dtype == torch.float32
+
+        hc_mult = residual.shape[-2]
+        hidden_size = residual.shape[-1]
+        hc_mult2 = hc_mult * hc_mult
+        hc_mult3 = hc_mult * 2 + hc_mult2
+
+        hc_hidden_size = hc_mult * hidden_size
+        assert fn.shape[0] == hc_mult3
+        assert fn.shape[1] == hc_hidden_size
+        assert hc_scale.shape == (3,)
+        assert hc_base.shape == (hc_mult3,)
+
+        outer_shape = residual.shape[:-2]
+
+        residual_flat = residual.view(-1, hc_mult, hidden_size)
+
+        num_tokens = residual_flat.shape[0]
+        if num_tokens == 0:
+            return (
+                torch.empty(
+                    num_tokens,
+                    hc_mult,
+                    1,
+                    dtype=torch.float32,
+                    device=residual_flat.device,
+                ),
+                torch.empty(
+                    num_tokens,
+                    hc_mult,
+                    hc_mult,
+                    dtype=torch.float32,
+                    device=residual_flat.device,
+                ),
+                torch.empty(
+                    num_tokens,
+                    hidden_size,
+                    dtype=torch.bfloat16,
+                    device=residual_flat.device,
+                ),
+            )
+
+        # AITER's Python wrapper allocates intermediate/output tensors without
+        # explicit device arguments, so run it under the residual tensor's device.
+        with torch.device(residual_flat.device):
+            post_mix, comb_mix, layer_input = mhc_pre(
+                residual_flat,
+                fn,
+                hc_scale,
+                hc_base,
+                rms_eps,
+                hc_pre_eps,
+                hc_sinkhorn_eps,
+                hc_post_mult_value,
+                sinkhorn_repeat,
+            )
+        return (
+            post_mix.view(*outer_shape, hc_mult, 1),
+            comb_mix.view(*outer_shape, hc_mult, hc_mult),
+            layer_input.view(*outer_shape, hidden_size),
+        )
+
+    @staticmethod
+    def hc_head(
+        hs_flat: torch.Tensor,
+        fn: torch.Tensor,
+        hc_scale: torch.Tensor,
+        hc_base: torch.Tensor,
+        out: torch.Tensor,
+        hidden_size: int,
+        rms_eps: float,
+        hc_eps: float,
+        hc_mult: int,
+    ) -> None:
+        """Run hc_head through AITER mhc_pre and write the result to out."""
+        assert hs_flat.dtype == torch.bfloat16
+        assert fn.dtype == torch.float32
+        assert hc_scale.dtype == torch.float32
+        assert hc_base.dtype == torch.float32
+        assert hs_flat.shape[-2:] == (hc_mult, hidden_size)
+        assert fn.shape == (hc_mult, hc_mult * hidden_size)
+        assert hc_scale.shape == (1,)
+        assert hc_base.shape == (hc_mult,)
+
+        num_tokens = hs_flat.shape[0]
+        if num_tokens == 0:
+            return
+
+        hc_mult3 = hc_mult * 2 + hc_mult * hc_mult
+
+        full_fn = torch.zeros(
+            hc_mult3,
+            hc_mult * hidden_size,
+            dtype=fn.dtype,
+            device=fn.device,
+        )
+        full_fn[:hc_mult] = fn
+
+        full_base = torch.zeros(hc_mult3, dtype=hc_base.dtype, device=hc_base.device)
+        full_base[:hc_mult] = hc_base
+
+        full_scale = torch.zeros(3, dtype=hc_scale.dtype, device=hc_scale.device)
+        full_scale[0] = hc_scale[0]
+
+        _, _, layer_input = rocm_aiter_ops.mhc_pre(
+            hs_flat,
+            full_fn,
+            full_scale,
+            full_base,
+            rms_eps,
+            hc_eps,
+            0.0,
+            1.0,
+            0,
+        )
+        out.copy_(layer_input)
+
+    @staticmethod
+    def mhc_post(
+        x: torch.Tensor,
+        residual: torch.Tensor,
+        post_layer_mix: torch.Tensor,
+        comb_res_mix: torch.Tensor,
+    ) -> torch.Tensor:
+        from aiter.ops.mhc import mhc_post
+
+        hc_mult = residual.shape[-2]
+        hidden_size = residual.shape[-1]
+        residual_flat = residual.view(-1, hc_mult, hidden_size)
+        num_tokens = residual_flat.shape[0]
+        out = torch.empty_like(residual_flat)
+        mhc_post(
+            out,
+            x.view(num_tokens, hidden_size),
+            residual_flat,
+            post_layer_mix.view(num_tokens, hc_mult, 1),
+            comb_res_mix.view(num_tokens, hc_mult, hc_mult),
+        )
+        return out.view_as(residual)
+ 
 
 rocm_aiter_ops.register_ops_once()

이 파일은 ROCm 환경에서 AITER 라이브러리의 연산을 vLLM의 연산으로 래핑하는 역할을 합니다. 이번 PR에서는 mhc_pre, hc_head, mhc_post와 같은 MHC 관련 연산들이 AITER 라이브러리의 함수를 호출하도록 추가되었습니다.

  • mhc_pre 함수는 AITER의 aiter.ops.mhc.mhc_pre를 호출하며, 입력 텐서의 유효성을 검사하고 최종 결과를 반환합니다. 특히, AITER의 Python 래퍼가 명시적인 장치 인자 없이 텐서를 할당하는 문제를 해결하기 위해 with torch.device(residual_flat.device): 컨텍스트 내에서 실행됩니다. 이는 리뷰 댓글에서 지적된 OOM(Out Of Memory) 문제를 방지하기 위한 중요한 조치입니다.
  • hc_head 함수는 mhc_pre를 사용하여 hc_head 연산을 수행하고 결과를 out 텐서에 복사합니다. 이는 hc_head 연산이 AITER의 mhc_pre 연산을 재활용하여 구현되었음을 보여줍니다. 리뷰어 tjtanaa는 이 함수가 hc_head_op으로 명명되어야 한다고 제안했지만, 여기서는 hc_head로 유지되었습니다. (이 부분은 PR 설명과 리뷰 댓글 간의 미묘한 차이가 있을 수 있습니다.)
  • mhc_post 함수 역시 AITER의 aiter.ops.mhc.mhc_post를 호출하여 MHC 후처리 연산을 수행합니다.

이러한 AITER 커널의 통합은 ROCm 환경에서 MHC 연산의 성능을 크게 향상시킬 것으로 기대됩니다. AITER는 AMD GPU에 최적화된 연산을 제공하기 때문입니다.

4. vllm/_tilelang_ops.py 신규 생성

Before: (이 파일은 존재하지 않았습니다.)

After:

--- /dev/null
+++ b/vllm/_tilelang_ops.py
@@ -0,0 +1,462 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+from functools import cache
+from typing import TYPE_CHECKING
+
+import torch
+
+from vllm.platforms import current_platform
+from vllm.utils.import_utils import has_tilelang
+from vllm.utils.math_utils import cdiv
+
+# tilelang is only available on CUDA platforms
+if TYPE_CHECKING or current_platform.is_cuda():
+    if not has_tilelang():
+        raise ImportError(
+            

## 참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html

> ⚠️ **알림:** 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글