본문으로 건너뛰기

[vllm] vLLM Mamba2 SSD 커널 웜업: 첫 요청 지연 시간 91% 감소의 비결

PR 링크: vllm-project/vllm#39822 상태: Merged | 변경: +0 / -0

들어가며

대규모 언어 모델(LLM)을 서빙할 때, 첫 번째 추론 요청의 지연 시간은 사용자 경험에 큰 영향을 미칩니다. 특히 Mamba2와 같은 새로운 아키텍처는 Triton과 같은 고성능 컴퓨팅 라이브러리를 활용하여 최적화된 커널을 사용하는데, 이 커널들이 '게으르게(lazily)' 동작하여 첫 요청 시점에 컴파일 및 튜닝되는 경우가 많습니다. 이는 첫 요청에서 상당한 지연 시간 스파이크를 유발하며, 실제 서비스 환경에서는 치명적인 단점이 될 수 있습니다.

vLLM 프로젝트의 이 PR은 이러한 문제를 해결하기 위해 Mamba2 SSD(State Space Model) 커널의 Triton 자동 튜닝을 서버 시작 단계로 옮겨, 첫 요청 지연 시간을 획기적으로 줄이는 최적화를 제안합니다. 이 글에서는 해당 PR의 코드 변경 사항을 분석하고, 이 최적화가 왜 중요하며 어떤 기술적 교훈을 주는지 살펴보겠습니다.

코드 분석: Mamba2 SSD 커널 웜업

이 PR의 핵심은 Mamba2 모델의 MambaMixer2 클래스에 _warmup_ssd_kernels() 메서드를 추가하고, vLLM의 프로파일링 단계에서 이 메서드를 호출하여 Triton 커널 자동 튜닝을 미리 수행하는 것입니다. 또한, ModelConfigget_mamba_chunk_size 메서드에서 Mamba1의 기본 청크 사이즈를 1024에서 2048로 변경하는 소소한 개선도 포함되어 있습니다.

vllm/config/model.py 변경 사항

이 파일에서는 ModelConfig 클래스의 get_mamba_chunk_size 메서드가 수정되었습니다. Mamba1 모델의 경우 명시적인 chunk_size가 없을 때 사용되는 기본값이 변경되었습니다.

Before:

    def get_mamba_chunk_size(self) -> int | None:
        """
        Returns the mamba chunk size if it exists
        """
        if self.hf_text_config is not None:
            chunk_size = getattr(self.hf_text_config, "chunk_size", None)

        # Since Mamba1 does not have a chunk notion
        # we use a default chunk size of 1024.
        if chunk_size is None:
            chunk_size = 1024

        return chunk_size

After:

    def get_mamba_chunk_size(self) -> int:
        """
        Returns the mamba chunk size if it exists
        """
        if self.hf_text_config is not None:
            chunk_size = getattr(self.hf_text_config, "chunk_size", None)

        # Since Mamba1 does not have a chunk notion
        # we use a default chunk size of 2048.
        if chunk_size is None:
            chunk_size = 2048

        return chunk_size
  • 무엇이 왜 좋은가: get_mamba_chunk_size의 반환 타입 힌트가 int | None에서 int로 변경되어, 항상 int를 반환함을 명확히 했습니다. 또한, Mamba1의 기본 chunk_size가 1024에서 2048로 변경되었습니다. 이는 Mamba2 모델의 웜업 시 사용되는 chunk_size와 일관성을 유지하고, 더 큰 청크 사이즈가 특정 시나리오에서 성능 이점을 제공할 수 있기 때문입니다. 리뷰어 tomeras91의 지적처럼, 문서와 코드의 불일치를 수정하고 더 합리적인 기본값을 설정한 것입니다.

vllm/model_executor/layers/mamba/gdn_linear_attn.py 변경 사항

이 파일에서는 GatedDeltaRule 클래스에 _prefill_kernels_warmed_up 플래그를 초기화하는 방식이 개선되었습니다. 이는 MambaMixer2의 웜업 로직과 일관성을 맞추기 위한 변경입니다.

Before:

    def _warmup_prefill_kernels(self, qkv_or_qkvz: torch.Tensor, v_dim: int) -> None:
        # ...
        if hasattr(self, "_prefill_kernels_warmed_up"):
            return
        self._prefill_kernels_warmed_up = True
        # ...

After:

    def __init__(
        # ...
    ):
        # ...
        self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
        self._prefill_kernels_warmed_up = False
        # ...

    def _warmup_prefill_kernels(self, qkv_or_qkvz: torch.Tensor, v_dim: int) -> None:
        # ...
        if self._prefill_kernels_warmed_up:
            return
        self._prefill_kernels_warmed_up = True
        # ...
  • 무엇이 왜 좋은가: _prefill_kernels_warmed_up 속성을 __init__ 메서드에서 명시적으로 False로 초기화하도록 변경했습니다. 이전에는 hasattr를 통해 속성 존재 여부를 확인했지만, 명시적인 초기화는 코드의 가독성과 예측 가능성을 높입니다. 이는 tomeras91의 리뷰 의견을 반영한 것으로, vLLM 내의 다른 유사한 패턴에도 적용될 수 있는 좋은 관행입니다.

vllm/model_executor/layers/mamba/mamba_mixer2.py 변경 사항 (핵심)

이 파일은 PR의 핵심 변경 사항을 담고 있습니다. MambaMixer2 클래스에 _warmup_ssd_kernels 메서드가 추가되고, 모델의 프로파일링 단계에서 이 메서드가 호출됩니다.

Before:

 # ... (imports)

 # Added by the IBM Team, 2024


class MambaMixer2(PluggableLayer):
    # ...
    def __init__(
        # ...
    ):
        # ...

    def forward(
        # ...
    ):
        # ...
        if attn_metadata is None:
            # profile run
            hidden_states_B_C = (
                hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
            ).contiguous()

After:

 # ... (imports)
 from vllm.logger import init_logger
 # ...

 logger = init_logger(__name__)

 # Added by the IBM Team, 2024


class MambaMixer2(PluggableLayer):
    # ...
    def __init__(
        # ...
    ):
        # ...
        self._ssd_kernels_warmed_up = False

        # - get hidden_states, B and C after depthwise convolution.
        self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
            hidden_states_B_C,
            # ...

    def _warmup_ssd_kernels(self, projected_states: torch.Tensor) -> None:
        """Run a minimal SSD forward pass to trigger Triton autotuning
        while GPU memory is still plentiful (before SSM cache allocation).
        """
        if self._ssd_kernels_warmed_up:
            return
        self._ssd_kernels_warmed_up = True
        logger.info_once("Warming up Mamba2 SSD Triton kernels...")

        device = projected_states.device
        dtype = projected_states.dtype

        nheads = self.num_heads // self.tp_size
        ngroups = self.n_groups // self.tp_size
        headdim = self.head_dim
        dstate = self.ssm_state_size

        if self.model_config is None:
            return
        chunk_size = self.model_config.get_mamba_chunk_size()

        # Triton's autotuner includes tensor dtypes in its cache key,
        # so state_dtype must match what real inference uses.
        _, ssm_state_dtype = self.get_state_dtype()

        # SSD kernel autotune keys depend on dtype and head dimensions,
        # not on sequence length or batch size, so a single shape suffices.
        seqlen = chunk_size
        batch = 1
        nchunks = seqlen // chunk_size  # = 1

        x = torch.randn(seqlen, nheads, headdim, device=device, dtype=dtype)
        dt = torch.randn(seqlen, nheads, device=device, dtype=dtype)
        B = torch.randn(seqlen, ngroups, dstate, device=device, dtype=dtype)
        C = torch.randn(seqlen, ngroups, dstate, device=device, dtype=dtype)
        cu_seqlens = torch.tensor([0, seqlen], device=device, dtype=torch.int32)
        cu_chunk_seqlens = torch.tensor(
            [i * chunk_size for i in range(nchunks + 1)],
            device=device,
            dtype=torch.int32,
        )
        last_chunk_indices = torch.tensor(
            [nchunks - 1], device=device, dtype=torch.int32
        )
        seq_idx = torch.zeros(nchunks, device=device, dtype=torch.int32)
        out = torch.empty(seqlen, nheads, headdim, device=device, dtype=dtype)

        # Two kernels (_state_passing_fwd, _chunk_scan_fwd) use
        # HAS_INITSTATES as a constexpr, producing separate compiled
        # binaries. Warm up both code paths so neither triggers
        # JIT compilation during inference.
        for use_initial_states in (False, True):
            initial_states = (
                torch.randn(
                    batch,
                    nheads,
                    headdim,
                    dstate,
                    device=device,
                    dtype=ssm_state_dtype,
                )
                if use_initial_states
                else None
            )
            try:
                mamba_chunk_scan_combined_varlen(
                    x=x,
                    dt=dt,
                    A=self.A,
                    B=B,
                    C=C,
                    chunk_size=chunk_size,
                    cu_seqlens=cu_seqlens,
                    cu_chunk_seqlens=cu_chunk_seqlens,
                    last_chunk_indices=last_chunk_indices,
                    seq_idx=seq_idx,
                    out=out,
                    D=self.D,
                    z=None,
                    dt_bias=self.dt_bias,
                    initial_states=initial_states,
                    dt_softplus=True,
                    dt_limit=(0.0, float("inf")),
                    state_dtype=ssm_state_dtype,
                )
            except Exception:
                logger.warning(
                    "Mamba2 SSD kernel warmup failed for layer %s "
                    "(initial_states=%s). First inference may experience "
                    "latency spike or OOM due to autotuner.",
                    self.prefix,
                    use_initial_states,
                    exc_info=True,
                )

        logger.debug("Mamba2 SSD kernel warmup completed for layer %s", self.prefix)
        torch.accelerator.empty_cache()

    def conv_ssm_forward(
        self,
        projected_states: torch.Tensor,
        # ...
    ):
        # ...
        if attn_metadata is None:
            # V1 profile run -- warm up SSD kernels so that autotuning
            # completes before SSM cache allocation.
            self._warmup_ssd_kernels(projected_states)
            hidden_states_B_C = (
                hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
            ).contiguous()
  • 무엇이 왜 좋은가:
    1. _ssd_kernels_warmed_up 플래그: __init__에서 _ssd_kernels_warmed_up = False로 초기화하고, _warmup_ssd_kernels 메서드 내에서 이 플래그를 사용하여 중복 웜업을 방지합니다. 이는 GatedDeltaRule의 변경과 일관성을 유지하며, 효율적인 웜업을 보장합니다.
    2. _warmup_ssd_kernels 메서드: 이 메서드는 mamba_chunk_scan_combined_varlen 커널을 더미 텐서로 한 번 실행하여 Triton의 자동 튜닝을 미리 트리거합니다. 특히, projected_statesdevicedtype을 사용하여 실제 추론 환경과 동일한 조건으로 웜업을 수행합니다. tomeras91의 제안에 따라 torch.zeros 대신 torch.randn을 사용하여 커널의 제로 패스트 패스를 회피하고 실제 시나리오에 더 가깝게 웜업합니다.
    3. HAS_INITSTATES 처리: mamba_chunk_scan_combined_varlen 커널은 HAS_INITSTATES라는 constexpr에 따라 두 가지 다른 코드 경로(initial_states 유무)를 가집니다. 이 PR은 for use_initial_states in (False, True): 루프를 통해 두 경로 모두를 웜업하여, 어떤 경우에도 JIT 컴파일이 추론 중에 발생하지 않도록 합니다. tomeras91의 지적처럼, 이는 자동 튜닝 키의 일부가 아니라 JIT 컴파일을 위한 것입니다.
    4. conv_ssm_forward 호출: MambaMixer2conv_ssm_forward 메서드에서 attn_metadata is None인 경우(즉, V1 프로파일링 실행 중) _warmup_ssd_kernels를 호출합니다. 이는 SSM 캐시 할당 전에 웜업을 완료하여 GPU 메모리가 충분할 때 튜닝이 이루어지도록 합니다.
    5. 로깅 개선: logger.info_once를 사용하여 전체 모델에 대해 한 번만 웜업 시작 메시지를 출력하고, 각 레이어의 웜업 완료는 logger.debug로 출력하여 로그의 가독성을 높였습니다. 이는 tomeras91의 리뷰 의견을 반영한 것입니다.
    6. 예외 처리: 웜업 중 예외가 발생할 경우 경고 로그를 남겨, 문제가 발생했음을 알리고 첫 추론 시 지연 시간 스파이크나 OOM이 발생할 수 있음을 명시합니다.
    7. torch.accelerator.empty_cache(): 웜업 후 GPU 캐시를 비워 불필요한 메모리 점유를 방지합니다.

왜 이게 좋은가: 성능 수치와 일반적 교훈

이 PR의 가장 큰 장점은 첫 요청 지연 시간을 극적으로 줄여 사용자 경험을 크게 개선한다는 점입니다. 벤치마크 결과는 이를 명확히 보여줍니다.

벤치마크 결과 (H100 80GB, nvidia/NVIDIA-Nemotron-3-Nano-4B-BF16)

Metric Baseline (main) With Warmup Change
Model load time 30.0s 76.9s +46.9s (autotuning shifted here)
First request latency 31.343s 2.890s -28.5s (91% reduction)
Subsequent request avg 0.083s 0.083s No change
First / subsequent ratio 378x 35x 10.8x improvement
  • 모델 로드 시간 증가: 웜업으로 인해 모델 로드 시간이 30.0초에서 76.9초로 증가했습니다. 이는 Triton 자동 튜닝 비용이 서버 시작 단계로 이동했기 때문입니다. 하지만 이는 한 번만 발생하는 비용이며, 서비스 시작 준비 시간을 늘리는 대신 사용자에게 직접적인 영향을 주는 첫 요청 지연 시간을 줄이는 데 기여합니다.
  • 첫 요청 지연 시간 91% 감소: 가장 중요한 지표인 첫 요청 지연 시간이 31.343초에서 2.890초로 무려 91% 감소했습니다. 이는 서비스의 반응성을 획기적으로 개선합니다.
  • 후속 요청 지연 시간 변화 없음: 웜업은 첫 요청에만 영향을 미치며, 이미 튜닝된 커널을 사용하는 후속 요청의 성능에는 영향을 주지 않습니다. 이는 최적화가 의도한 대로 동작함을 의미합니다.
  • 첫 요청/후속 요청 비율 10.8배 개선: 이 비율은 첫 요청의 비효율성을 나타내는데, 378배에서 35배로 크게 개선되었습니다. 이는 시스템이 훨씬 더 빠르게 안정적인 성능을 제공할 수 있음을 의미합니다.

TRITON_PRINT_AUTOTUNING=1 환경 변수를 통해 확인한 결과, 웜업 브랜치에서는 모델 로드 후 SSD 커널 자동 튜닝 이벤트가 전혀 발생하지 않았습니다. 모든 튜닝이 초기화 중에 완료된 것입니다.

일반적 교훈

이 최적화는 딥러닝 모델 서빙 시스템에서 다음과 같은 중요한 교훈을 제공합니다.

  1. 지연 로딩(Lazy Loading)의 함정: Triton과 같은 고성능 라이브러리의 자동 튜닝 기능은 편리하지만, 첫 사용 시점에 발생하는 비용이 서비스의 사용자 경험을 저해할 수 있습니다. 특히 LLM처럼 첫 요청이 중요한 서비스에서는 이러한 지연 로딩을 경계해야 합니다.
  2. 초기화 단계로 비용 이동: 서비스 시작 시간은 다소 길어지더라도, 사용자에게 직접적인 영향을 미치는 첫 요청 지연 시간을 줄이는 것이 중요합니다. 비용을 초기화 단계로 옮기는 것은 이러한 트레이드오프를 현명하게 관리하는 전략입니다.
  3. 커널 웜업의 중요성: GPU 커널, 특히 자동 튜닝이 필요한 커널은 실제 추론에 앞서 더미 데이터를 사용하여 미리 웜업하는 것이 좋습니다. 이는 JIT 컴파일 및 튜닝 비용을 서비스 시작 시점으로 이동시켜 런타임 성능을 안정화합니다.
  4. 정확한 웜업 조건: 웜업 시 사용되는 텐서의 dtype, device, 그리고 커널의 constexpr 경로(예: HAS_INITSTATES)는 실제 추론과 동일하게 설정해야 합니다. Triton의 자동 튜닝 캐시 키는 이러한 요소들을 포함하기 때문입니다.
  5. 로깅과 디버깅: TRITON_PRINT_AUTOTUNING=1과 같은 환경 변수나 상세한 로깅은 최적화의 효과를 검증하고 문제를 진단하는 데 필수적입니다.

이 PR은 vLLM이 Mamba2 모델을 효율적으로 서빙하기 위한 중요한 단계이며, 고성능 딥러닝 추론 시스템을 구축하는 데 있어 실용적인 최적화 기법을 잘 보여줍니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글