[vllm] vLLM Mamba2 SSD 커널 웜업: 첫 요청 지연 시간 91% 감소의 비결
PR 링크: vllm-project/vllm#39822 상태: Merged | 변경: +0 / -0
들어가며
대규모 언어 모델(LLM)을 서빙할 때, 첫 번째 추론 요청의 지연 시간은 사용자 경험에 큰 영향을 미칩니다. 특히 Mamba2와 같은 새로운 아키텍처는 Triton과 같은 고성능 컴퓨팅 라이브러리를 활용하여 최적화된 커널을 사용하는데, 이 커널들이 '게으르게(lazily)' 동작하여 첫 요청 시점에 컴파일 및 튜닝되는 경우가 많습니다. 이는 첫 요청에서 상당한 지연 시간 스파이크를 유발하며, 실제 서비스 환경에서는 치명적인 단점이 될 수 있습니다.
vLLM 프로젝트의 이 PR은 이러한 문제를 해결하기 위해 Mamba2 SSD(State Space Model) 커널의 Triton 자동 튜닝을 서버 시작 단계로 옮겨, 첫 요청 지연 시간을 획기적으로 줄이는 최적화를 제안합니다. 이 글에서는 해당 PR의 코드 변경 사항을 분석하고, 이 최적화가 왜 중요하며 어떤 기술적 교훈을 주는지 살펴보겠습니다.
코드 분석: Mamba2 SSD 커널 웜업
이 PR의 핵심은 Mamba2 모델의 MambaMixer2 클래스에 _warmup_ssd_kernels() 메서드를 추가하고, vLLM의 프로파일링 단계에서 이 메서드를 호출하여 Triton 커널 자동 튜닝을 미리 수행하는 것입니다. 또한, ModelConfig의 get_mamba_chunk_size 메서드에서 Mamba1의 기본 청크 사이즈를 1024에서 2048로 변경하는 소소한 개선도 포함되어 있습니다.
vllm/config/model.py 변경 사항
이 파일에서는 ModelConfig 클래스의 get_mamba_chunk_size 메서드가 수정되었습니다. Mamba1 모델의 경우 명시적인 chunk_size가 없을 때 사용되는 기본값이 변경되었습니다.
Before:
def get_mamba_chunk_size(self) -> int | None:
"""
Returns the mamba chunk size if it exists
"""
if self.hf_text_config is not None:
chunk_size = getattr(self.hf_text_config, "chunk_size", None)
# Since Mamba1 does not have a chunk notion
# we use a default chunk size of 1024.
if chunk_size is None:
chunk_size = 1024
return chunk_size
After:
def get_mamba_chunk_size(self) -> int:
"""
Returns the mamba chunk size if it exists
"""
if self.hf_text_config is not None:
chunk_size = getattr(self.hf_text_config, "chunk_size", None)
# Since Mamba1 does not have a chunk notion
# we use a default chunk size of 2048.
if chunk_size is None:
chunk_size = 2048
return chunk_size
- 무엇이 왜 좋은가:
get_mamba_chunk_size의 반환 타입 힌트가int | None에서int로 변경되어, 항상int를 반환함을 명확히 했습니다. 또한, Mamba1의 기본chunk_size가 1024에서 2048로 변경되었습니다. 이는 Mamba2 모델의 웜업 시 사용되는chunk_size와 일관성을 유지하고, 더 큰 청크 사이즈가 특정 시나리오에서 성능 이점을 제공할 수 있기 때문입니다. 리뷰어tomeras91의 지적처럼, 문서와 코드의 불일치를 수정하고 더 합리적인 기본값을 설정한 것입니다.
vllm/model_executor/layers/mamba/gdn_linear_attn.py 변경 사항
이 파일에서는 GatedDeltaRule 클래스에 _prefill_kernels_warmed_up 플래그를 초기화하는 방식이 개선되었습니다. 이는 MambaMixer2의 웜업 로직과 일관성을 맞추기 위한 변경입니다.
Before:
def _warmup_prefill_kernels(self, qkv_or_qkvz: torch.Tensor, v_dim: int) -> None:
# ...
if hasattr(self, "_prefill_kernels_warmed_up"):
return
self._prefill_kernels_warmed_up = True
# ...
After:
def __init__(
# ...
):
# ...
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
self._prefill_kernels_warmed_up = False
# ...
def _warmup_prefill_kernels(self, qkv_or_qkvz: torch.Tensor, v_dim: int) -> None:
# ...
if self._prefill_kernels_warmed_up:
return
self._prefill_kernels_warmed_up = True
# ...
- 무엇이 왜 좋은가:
_prefill_kernels_warmed_up속성을__init__메서드에서 명시적으로False로 초기화하도록 변경했습니다. 이전에는hasattr를 통해 속성 존재 여부를 확인했지만, 명시적인 초기화는 코드의 가독성과 예측 가능성을 높입니다. 이는tomeras91의 리뷰 의견을 반영한 것으로, vLLM 내의 다른 유사한 패턴에도 적용될 수 있는 좋은 관행입니다.
vllm/model_executor/layers/mamba/mamba_mixer2.py 변경 사항 (핵심)
이 파일은 PR의 핵심 변경 사항을 담고 있습니다. MambaMixer2 클래스에 _warmup_ssd_kernels 메서드가 추가되고, 모델의 프로파일링 단계에서 이 메서드가 호출됩니다.
Before:
# ... (imports)
# Added by the IBM Team, 2024
class MambaMixer2(PluggableLayer):
# ...
def __init__(
# ...
):
# ...
def forward(
# ...
):
# ...
if attn_metadata is None:
# profile run
hidden_states_B_C = (
hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
).contiguous()
After:
# ... (imports)
from vllm.logger import init_logger
# ...
logger = init_logger(__name__)
# Added by the IBM Team, 2024
class MambaMixer2(PluggableLayer):
# ...
def __init__(
# ...
):
# ...
self._ssd_kernels_warmed_up = False
# - get hidden_states, B and C after depthwise convolution.
self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
hidden_states_B_C,
# ...
def _warmup_ssd_kernels(self, projected_states: torch.Tensor) -> None:
"""Run a minimal SSD forward pass to trigger Triton autotuning
while GPU memory is still plentiful (before SSM cache allocation).
"""
if self._ssd_kernels_warmed_up:
return
self._ssd_kernels_warmed_up = True
logger.info_once("Warming up Mamba2 SSD Triton kernels...")
device = projected_states.device
dtype = projected_states.dtype
nheads = self.num_heads // self.tp_size
ngroups = self.n_groups // self.tp_size
headdim = self.head_dim
dstate = self.ssm_state_size
if self.model_config is None:
return
chunk_size = self.model_config.get_mamba_chunk_size()
# Triton's autotuner includes tensor dtypes in its cache key,
# so state_dtype must match what real inference uses.
_, ssm_state_dtype = self.get_state_dtype()
# SSD kernel autotune keys depend on dtype and head dimensions,
# not on sequence length or batch size, so a single shape suffices.
seqlen = chunk_size
batch = 1
nchunks = seqlen // chunk_size # = 1
x = torch.randn(seqlen, nheads, headdim, device=device, dtype=dtype)
dt = torch.randn(seqlen, nheads, device=device, dtype=dtype)
B = torch.randn(seqlen, ngroups, dstate, device=device, dtype=dtype)
C = torch.randn(seqlen, ngroups, dstate, device=device, dtype=dtype)
cu_seqlens = torch.tensor([0, seqlen], device=device, dtype=torch.int32)
cu_chunk_seqlens = torch.tensor(
[i * chunk_size for i in range(nchunks + 1)],
device=device,
dtype=torch.int32,
)
last_chunk_indices = torch.tensor(
[nchunks - 1], device=device, dtype=torch.int32
)
seq_idx = torch.zeros(nchunks, device=device, dtype=torch.int32)
out = torch.empty(seqlen, nheads, headdim, device=device, dtype=dtype)
# Two kernels (_state_passing_fwd, _chunk_scan_fwd) use
# HAS_INITSTATES as a constexpr, producing separate compiled
# binaries. Warm up both code paths so neither triggers
# JIT compilation during inference.
for use_initial_states in (False, True):
initial_states = (
torch.randn(
batch,
nheads,
headdim,
dstate,
device=device,
dtype=ssm_state_dtype,
)
if use_initial_states
else None
)
try:
mamba_chunk_scan_combined_varlen(
x=x,
dt=dt,
A=self.A,
B=B,
C=C,
chunk_size=chunk_size,
cu_seqlens=cu_seqlens,
cu_chunk_seqlens=cu_chunk_seqlens,
last_chunk_indices=last_chunk_indices,
seq_idx=seq_idx,
out=out,
D=self.D,
z=None,
dt_bias=self.dt_bias,
initial_states=initial_states,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
state_dtype=ssm_state_dtype,
)
except Exception:
logger.warning(
"Mamba2 SSD kernel warmup failed for layer %s "
"(initial_states=%s). First inference may experience "
"latency spike or OOM due to autotuner.",
self.prefix,
use_initial_states,
exc_info=True,
)
logger.debug("Mamba2 SSD kernel warmup completed for layer %s", self.prefix)
torch.accelerator.empty_cache()
def conv_ssm_forward(
self,
projected_states: torch.Tensor,
# ...
):
# ...
if attn_metadata is None:
# V1 profile run -- warm up SSD kernels so that autotuning
# completes before SSM cache allocation.
self._warmup_ssd_kernels(projected_states)
hidden_states_B_C = (
hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
).contiguous()
- 무엇이 왜 좋은가:
_ssd_kernels_warmed_up플래그:__init__에서_ssd_kernels_warmed_up = False로 초기화하고,_warmup_ssd_kernels메서드 내에서 이 플래그를 사용하여 중복 웜업을 방지합니다. 이는GatedDeltaRule의 변경과 일관성을 유지하며, 효율적인 웜업을 보장합니다._warmup_ssd_kernels메서드: 이 메서드는mamba_chunk_scan_combined_varlen커널을 더미 텐서로 한 번 실행하여 Triton의 자동 튜닝을 미리 트리거합니다. 특히,projected_states의device와dtype을 사용하여 실제 추론 환경과 동일한 조건으로 웜업을 수행합니다.tomeras91의 제안에 따라torch.zeros대신torch.randn을 사용하여 커널의 제로 패스트 패스를 회피하고 실제 시나리오에 더 가깝게 웜업합니다.HAS_INITSTATES처리:mamba_chunk_scan_combined_varlen커널은HAS_INITSTATES라는constexpr에 따라 두 가지 다른 코드 경로(initial_states 유무)를 가집니다. 이 PR은for use_initial_states in (False, True):루프를 통해 두 경로 모두를 웜업하여, 어떤 경우에도 JIT 컴파일이 추론 중에 발생하지 않도록 합니다.tomeras91의 지적처럼, 이는 자동 튜닝 키의 일부가 아니라 JIT 컴파일을 위한 것입니다.conv_ssm_forward호출:MambaMixer2의conv_ssm_forward메서드에서attn_metadata is None인 경우(즉, V1 프로파일링 실행 중)_warmup_ssd_kernels를 호출합니다. 이는 SSM 캐시 할당 전에 웜업을 완료하여 GPU 메모리가 충분할 때 튜닝이 이루어지도록 합니다.- 로깅 개선:
logger.info_once를 사용하여 전체 모델에 대해 한 번만 웜업 시작 메시지를 출력하고, 각 레이어의 웜업 완료는logger.debug로 출력하여 로그의 가독성을 높였습니다. 이는tomeras91의 리뷰 의견을 반영한 것입니다. - 예외 처리: 웜업 중 예외가 발생할 경우 경고 로그를 남겨, 문제가 발생했음을 알리고 첫 추론 시 지연 시간 스파이크나 OOM이 발생할 수 있음을 명시합니다.
torch.accelerator.empty_cache(): 웜업 후 GPU 캐시를 비워 불필요한 메모리 점유를 방지합니다.
왜 이게 좋은가: 성능 수치와 일반적 교훈
이 PR의 가장 큰 장점은 첫 요청 지연 시간을 극적으로 줄여 사용자 경험을 크게 개선한다는 점입니다. 벤치마크 결과는 이를 명확히 보여줍니다.
벤치마크 결과 (H100 80GB, nvidia/NVIDIA-Nemotron-3-Nano-4B-BF16)
| Metric | Baseline (main) | With Warmup | Change |
|---|---|---|---|
| Model load time | 30.0s | 76.9s | +46.9s (autotuning shifted here) |
| First request latency | 31.343s | 2.890s | -28.5s (91% reduction) |
| Subsequent request avg | 0.083s | 0.083s | No change |
| First / subsequent ratio | 378x | 35x | 10.8x improvement |
- 모델 로드 시간 증가: 웜업으로 인해 모델 로드 시간이 30.0초에서 76.9초로 증가했습니다. 이는 Triton 자동 튜닝 비용이 서버 시작 단계로 이동했기 때문입니다. 하지만 이는 한 번만 발생하는 비용이며, 서비스 시작 준비 시간을 늘리는 대신 사용자에게 직접적인 영향을 주는 첫 요청 지연 시간을 줄이는 데 기여합니다.
- 첫 요청 지연 시간 91% 감소: 가장 중요한 지표인 첫 요청 지연 시간이 31.343초에서 2.890초로 무려 91% 감소했습니다. 이는 서비스의 반응성을 획기적으로 개선합니다.
- 후속 요청 지연 시간 변화 없음: 웜업은 첫 요청에만 영향을 미치며, 이미 튜닝된 커널을 사용하는 후속 요청의 성능에는 영향을 주지 않습니다. 이는 최적화가 의도한 대로 동작함을 의미합니다.
- 첫 요청/후속 요청 비율 10.8배 개선: 이 비율은 첫 요청의 비효율성을 나타내는데, 378배에서 35배로 크게 개선되었습니다. 이는 시스템이 훨씬 더 빠르게 안정적인 성능을 제공할 수 있음을 의미합니다.
TRITON_PRINT_AUTOTUNING=1 환경 변수를 통해 확인한 결과, 웜업 브랜치에서는 모델 로드 후 SSD 커널 자동 튜닝 이벤트가 전혀 발생하지 않았습니다. 모든 튜닝이 초기화 중에 완료된 것입니다.
일반적 교훈
이 최적화는 딥러닝 모델 서빙 시스템에서 다음과 같은 중요한 교훈을 제공합니다.
- 지연 로딩(Lazy Loading)의 함정: Triton과 같은 고성능 라이브러리의 자동 튜닝 기능은 편리하지만, 첫 사용 시점에 발생하는 비용이 서비스의 사용자 경험을 저해할 수 있습니다. 특히 LLM처럼 첫 요청이 중요한 서비스에서는 이러한 지연 로딩을 경계해야 합니다.
- 초기화 단계로 비용 이동: 서비스 시작 시간은 다소 길어지더라도, 사용자에게 직접적인 영향을 미치는 첫 요청 지연 시간을 줄이는 것이 중요합니다. 비용을 초기화 단계로 옮기는 것은 이러한 트레이드오프를 현명하게 관리하는 전략입니다.
- 커널 웜업의 중요성: GPU 커널, 특히 자동 튜닝이 필요한 커널은 실제 추론에 앞서 더미 데이터를 사용하여 미리 웜업하는 것이 좋습니다. 이는 JIT 컴파일 및 튜닝 비용을 서비스 시작 시점으로 이동시켜 런타임 성능을 안정화합니다.
- 정확한 웜업 조건: 웜업 시 사용되는 텐서의
dtype,device, 그리고 커널의constexpr경로(예:HAS_INITSTATES)는 실제 추론과 동일하게 설정해야 합니다. Triton의 자동 튜닝 캐시 키는 이러한 요소들을 포함하기 때문입니다. - 로깅과 디버깅:
TRITON_PRINT_AUTOTUNING=1과 같은 환경 변수나 상세한 로깅은 최적화의 효과를 검증하고 문제를 진단하는 데 필수적입니다.
이 PR은 vLLM이 Mamba2 모델을 효율적으로 서빙하기 위한 중요한 단계이며, 고성능 딥러닝 추론 시스템을 구축하는 데 있어 실용적인 최적화 기법을 잘 보여줍니다.
참고 자료
- https://github.com/vllm-project/vllm/blob/main/vllm/config/model.py
- https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/gdn_linear_attn.py
- https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_mixer2.py
- https://pytorch.org/docs/stable/generated/torch.randn.html
- https://pytorch.org/docs/stable/generated/torch.tensor.html
- https://pytorch.org/docs/stable/generated/torch.empty.html
- https://pytorch.org/docs/stable/generated/torch.accelerator.empty_cache.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [vllm] vLLM의 Triton 통합 어텐션 커널에 Tensor Descriptor 최적화 도입
- [flashinfer] FlashInfer Mamba SSU 커널 최적화: Async State Prefetching과 Vectorized Load를 통한 성능 혁신
- [vllm] vLLM, DeepSeek-V4 K 캐시 커널 최적화: CuteDSL 도입으로 성능 향상
- [vllm] [vLLM] ROCm 환경에서의 DeepSeek-V2/V3 성능 극대화를 위한 MLA 최적화 분석
- [vllm] vLLM의 첫 추론 지연 문제 해결: forward_native 샘플러 커널 웜업 최적화
PR Analysis 의 다른글
- 이전글 [onnxruntime] [ONNX Runtime] PagedAttention의 FA 경로 최적화 및 정확성 개선
- 현재글 : [vllm] vLLM Mamba2 SSD 커널 웜업: 첫 요청 지연 시간 91% 감소의 비결
- 다음글 [sglang] NPU 성능 향상을 위한 causal_conv1d_update_v2 도입
댓글