본문으로 건너뛰기

[vllm] vLLM의 Mamba 모델 성능 최적화: Conv State 레이아웃 개선

PR 링크: vllm-project/vllm#37416 상태: Merged | 변경: +None / -None

들어가며

vLLM에서 Mamba 모델을 구동할 때, 기존의 Conv State 레이아웃은 분산 환경(Disaggregated scenarios)에서 비효율적인 메모리 접근과 복잡한 셔딩(sharding) 문제를 야기했습니다. 특히, Conv State가 (state_len, dim) 형태인 SD(State-Dim) 레이아웃으로 저장되어 있어, TP(Tensor Parallelism) 샤드를 선택하거나 전송할 때 추가적인 버퍼링과 전치(transpose) 연산이 필수적이었습니다. 본 PR은 이를 (dim, state_len) 형태인 DS(Dim-State) 레이아웃으로 변경하여 성능을 최적화하고, HeterogeneousTP를 위한 기반을 마련했습니다.

코드 분석

1. 레이아웃 설정 및 환경 변수 추가 (vllm/envs.py)

새로운 환경 변수 VLLM_SSM_CONV_STATE_LAYOUT을 도입하여 사용자가 SDDS 레이아웃을 선택할 수 있도록 했습니다.

# vllm/envs.py
"VLLM_SSM_CONV_STATE_LAYOUT": env_with_choices(
    "VLLM_SSM_CONV_STATE_LAYOUT", None, ["SD", "DS"]
),

2. 연산 로직의 유연한 대응 (vllm/model_executor/layers/mamba/mamba_mixer2.py 등)

기존에는 모든 상황에서 .transpose(-1, -2)를 호출하여 레이아웃을 맞췄으나, 이제는 is_conv_state_dim_first() 함수를 통해 레이아웃을 확인하고 필요한 경우에만 전치 연산을 수행합니다.

Before:

conv_state = self_kv_cache[0].transpose(-1, -2)

After:

conv_state = (
    self.kv_cache[0]
    if is_conv_state_dim_first()
    else self.kv_cache[0].transpose(-1, -2)
)

이 변경을 통해 DS 레이아웃 사용 시 불필요한 메모리 복사 및 전치 연산을 제거했습니다.

왜 이게 좋은가

성능 향상

벤치마크 결과, DS 레이아웃은 특히 TTFT(Time To First Token) 측면에서 기존 SD 레이아웃 대비 최대 약 1.5배의 성능 향상을 보였습니다. 이는 메모리 레이아웃이 연산 커널의 데이터 접근 패턴과 일치하게 되어 캐시 효율이 극대화되었기 때문입니다.

기술적 교훈

  1. 데이터 레이아웃의 중요성: 텐서의 차원 순서(Layout)는 단순히 데이터의 나열이 아니라, 하드웨어 가속기(GPU)의 메모리 대역폭 활용도와 직결됩니다. dim을 첫 번째 차원으로 두면 TP 샤딩 시 연속적인 메모리 읽기가 가능해져 오버헤드가 줄어듭니다.
  2. 추상화를 통한 유연성: is_conv_state_dim_first()와 같은 유틸리티 함수를 통해 하드웨어/설정별로 최적화된 경로를 선택하게 함으로써, 기존 코드의 호환성을 유지하면서도 성능 개선을 이끌어낼 수 있습니다.

리뷰어 피드백 반영

리뷰 과정에서 align 모드와 관련된 프리픽스 캐싱 이슈가 논의되었으며, 이는 향후 별도 이슈(#38898)를 통해 추적하기로 했습니다. 또한, 성능 측정 시 발생할 수 있는 지터(jitter)를 제거하기 위해 벤치마크 환경을 정밀하게 조정하는 등, 단순한 코드 변경을 넘어 실제 프로덕션 환경에서의 안정성을 검증하는 과정을 거쳤습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글