[sglang] Mamba GDN의 컨볼루션 캐시 최적화: 메모리 사용량 절반으로 줄이기
PR 링크: sgl-project/sglang#28302 상태: Merged | 변경: +283 / -20
들어가며
최근 LLM 개발에서 Mamba와 같은 하이브리드 선형 어텐션 모델이 주목받고 있습니다. 이러한 모델들은 기존 트랜스포머의 계산 복잡성을 개선하면서도 강력한 성능을 유지합니다. 특히, Mamba의 스페셜티브 디코딩(speculative decoding) 과정에서 발생하는 컨볼루션(conv) 연산은 성능에 중요한 영향을 미칩니다. SGLang 레포지토리의 이 PR은 Mamba 및 GDN(Generalized Decoding Network) 모델에서 스페셜티브 디코딩 시 사용되는 컨볼루션 중간 상태 캐시(intermediate_conv_window)의 메모리 사용량을 절반으로 줄이는 혁신적인 최적화를 소개합니다.
기존 방식에서는 각 드래프트 토큰마다 이전 K-1개의 입력에 대한 컨볼루션 상태 창(conv-state window)을 저장했습니다. 이는 [num_layers, slots, draft_tokens, conv_dim, K-1] 크기의 캐시를 필요로 했는데, 연속된 드래프트 토큰들의 창은 K-2만큼 겹치기 때문에 상당한 메모리 낭비가 발생했습니다. 이 PR은 이러한 중복을 제거하여 메모리 효율성을 크게 향상시킵니다.
코드 분석
이번 최적화는 주로 sglang/srt/layers/attention/hybrid_linear_attn_backend.py와 sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py, 그리고 sglang/srt/mem_cache/memory_pool.py 파일에서 이루어졌습니다.
1. sglang/srt/mem_cache/memory_pool.py: 슬라이딩 윈도우 레이아웃 도입
가장 핵심적인 변경은 MambaPool 클래스의 __init__ 메서드에서 intermediate_conv_window_cache를 초기화하는 부분입니다. 기존에는 각 드래프트 토큰마다 독립적인 컨볼루션 창을 할당했지만, 이제는 CUDA 환경에서 슬라이딩 윈도우 레이아웃을 사용하여 메모리를 절반으로 줄입니다.
Before:
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
intermediate_conv_window_cache = [
torch.zeros(
size=(
num_mamba_layers,
spec_state_size + 1,
speculative_num_draft_tokens,
conv_shape[0],
conv_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
for conv_shape in conv_state_shape
]
After:
# On CUDA (Triton conv kernel + Triton scatter) we use a
# *deduplicated sliding-window* layout: consecutive draft tokens'
# (K-1)-wide windows overlap by (K-2), so instead of D separate
# [dim, K-1] windows we store one shared [dim, D+K-2] buffer per
# (layer, slot) and expose an overlapping `as_strided` view of
# logical shape [num_layers, size+1, draft_tokens, dim, K-1] where
# step `t`'s window is the slice shared[..., :, t:t+K-1]. This
# halves the conv-intermediate footprint (D*(K-1) -> D+K-2 columns)
# with no numerical change: both the conv kernel write (idempotent
# overlapping stores) and `fused_conv_window_scatter_with_mask`
# consume the view
intermediate_conv_window_cache = [
torch.empty(
size=(
num_mamba_layers,
spec_state_size + 1,
conv_shape[0] + conv_shape[1] - 2,
conv_shape[0],
conv_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
for conv_shape in conv_state_shape
]
핵심 변화는 speculative_num_draft_tokens 차원이 conv_shape[0] + conv_shape[1] - 2로 줄어든 것입니다. 이는 D개의 (conv_dim, K-1) 창 대신, D+K-2개의 열을 가진 공유 버퍼를 사용하고, as_strided 뷰를 통해 논리적인 [num_layers, slots, draft_tokens, conv_dim, K-1] 형태를 제공하는 방식입니다. 예를 들어 D=4, K=4일 때, 기존에는 4 * (4-1) = 12개의 열이 필요했지만, 개선 후에는 4 + 4 - 2 = 6개의 열만 필요하게 되어 정확히 0.5배의 메모리 절약 효과를 얻습니다.
또한, conv_window_dedup_enabled 함수가 추가되어 이 최적화가 NPU나 CPU에서는 적용되지 않고, 스페셜티브 디코딩의 speculative_eagle_topk 값이 1 이하일 때 (즉, 선형적인 드래프트 체인일 때)만 활성화되도록 조건을 명확히 했습니다. 이는 리뷰어 BBuf의 피드백을 반영한 것입니다.
2. sglang/srt/layers/attention/hybrid_linear_attn_backend.py: 새로운 Scatter 커널 사용
update_mamba_state_after_mtp_verify 함수에서는 스페셜티브 디코딩 후 Mamba 상태를 업데이트하는 로직이 수정되었습니다. 기존에는 fused_mamba_state_scatter_with_mask를 사용했지만, 이제는 새로운 fused_conv_window_scatter_with_mask를 사용합니다.
Before:
fused_mamba_state_scatter_with_mask(
conv_states,
intermediate_conv_window_cache,
state_indices_tensor,
last_correct_step_indices,
)
After:
# conv intermediate uses the deduplicated sliding-window (overlapping)
# layout, so it needs the strided-read scatter variant.
fused_conv_window_scatter_with_mask(
conv_states,
intermediate_conv_window_cache,
state_indices_tensor,
last_correct_step_indices,
)
이 변경은 슬라이딩 윈도우 레이아웃으로 인해 intermediate_conv_window_cache가 더 이상 연속적이지 않기 때문에 필요합니다. 새로운 fused_conv_window_scatter_with_mask 함수는 이 비연속적인 뷰를 올바르게 처리하도록 설계되었습니다.
3. sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py: fused_conv_window_scatter_with_mask 구현
이 파일에는 새로운 fused_conv_window_scatter_with_mask 함수와 관련 커널(_fused_conv_window_scatter_with_mask_kernel)이 구현되었습니다. 이 커널은 슬라이딩 윈도우 레이아웃의 비연속적인 소스 텐서에서 데이터를 읽어와 연속적인 대상 텐서에 쓰는 역할을 합니다.
주요 로직은 다음과 같습니다:
- 비연속성 처리:
_fused_conv_window_scatter_with_mask_kernel은src텐서의 각(dim, win)요소를as_strided뷰의 스트라이드를 통해 개별적으로 인덱싱합니다. 이는 기존의fused_mamba_state_scatter_with_mask가 연속적인 행 전체를 복사하는 것과 대조적입니다. - 메모리 절약: 공유 버퍼(
[dim, D+K-2])를 사용하고as_strided뷰를 통해 접근함으로써,D*(K-1)크기의 메모리 대신D+K-2크기의 메모리만 사용하게 됩니다. - 안정성: 컨볼루션 커널은 겹치는 열에 동일한 입력을 쓰기 때문에, 이 변경은 수치적으로 동일하며 경쟁 상태(race condition)도 발생하지 않습니다. 이는 PR 설명에서 'idempotent and race-safe'라고 언급된 부분입니다.
리뷰어 BBuf는 dst 텐서가 반드시 연속적이어야 한다는 검증 로직을 추가하도록 제안했으며, 이는 fused_conv_window_scatter_with_mask 함수 내에 반영되었습니다.
4. sglang/srt/layers/attention/hybrid_linear_attn_backend.py: speculative_eagle_topk 파라미터 전달
DecodeReqToTokenPool.__init__ 호출 시 speculative_eagle_topk 파라미터가 전달되도록 __init__ 메서드가 수정되었습니다. 이는 conv_window_dedup_enabled 함수에서 이 값을 사용하여 최적화 적용 여부를 결정하기 위함입니다.
Before:
DecodeReqToTokenPool.__init__(
# ... other args
speculative_num_draft_tokens=speculative_num_draft_tokens,
)
After:
DecodeReqToTokenPool.__init__(
# ... other args
speculative_num_draft_tokens=speculative_num_draft_tokens,
speculative_eagle_topk=speculative_eagle_topk,
)
왜 이게 좋은가?
이 PR의 가장 큰 장점은 메모리 사용량을 절반으로 줄였다는 점입니다. 이는 특히 대규모 모델을 서비스하거나 메모리 제약이 있는 환경에서 매우 중요합니다.
-
성능 향상 (간접적): 직접적인 연산 속도 향상은 아니지만, 메모리 사용량 감소는 다음과 같은 간접적인 이점을 제공합니다:
- 더 높은 동시성: 동일한 하드웨어에서 더 많은 요청을 동시에 처리할 수 있습니다.
- OOM(Out-Of-Memory) 위험 감소: 모델 실행 중 메모리 부족으로 인한 실패 가능성을 줄입니다.
- 캐시 효율성 증대: 확보된 메모리 공간을 다른 용도로 활용할 가능성을 열어둡니다 (후속 작업으로 제안됨).
-
수치적 무결성: PR 설명에 명시된 바와 같이, 이 변경은 수치적으로 손실이 없습니다 (
max_abs_diff = 0.0). CPU/Numpy 테스트, GPU 비트-정확도 테스트, 그리고 실제 E2E 바이트 동일성 테스트를 통해 검증되었습니다. -
일반화 가능성: 이 최적화는 GDN뿐만 아니라, 스페셜티브 디코딩을 사용하는 모든 하이브리드 선형 어텐션 모델(특히 짧은 인과 컨볼루션을 사용하는 경우)에 적용될 수 있습니다. KDA와 같은 미래 모델에도 자동으로 적용될 수 있습니다.
-
코드 명확성 및 견고함: 리뷰어들의 피드백을 반영하여, 최적화 적용 조건을 명확히 하고(NPU/CPU 제외,
topk <= 1), 관련 검증 로직을 추가하여 코드의 견고성을 높였습니다.
벤치마크 결과:
| Model | TP | slots | dense | dedup |
|---|---|---|---|---|
| Qwen3.5-0.8B | 1 | 1183 | 0.04 GB | 0.02 GB |
| Qwen3.5-35B-A3B | 2 | 745 | 0.09 GB/rank | 0.05 GB/rank |
이 표는 intermediate_conv_window 할당량만 비교한 것으로, 모델 크기와 동시성(slots)이 증가할수록 절약되는 메모리 비율도 커짐을 보여줍니다.
결론
이 PR은 Mamba 및 GDN 모델의 스페셜티브 디코딩 과정에서 발생하는 컨볼루션 캐시의 메모리 사용량을 절반으로 줄이는 매우 효과적인 최적화를 성공적으로 구현했습니다. 슬라이딩 윈도우 레이아웃과 새로운 Triton 커널을 통해 메모리 효율성을 높이면서도 수치적 정확성을 유지했으며, 이는 LLM 서빙 환경에서 중요한 이점을 제공합니다. 이 최적화는 코드의 명확성과 견고성을 높이는 리뷰 과정까지 거쳐 완성되었습니다.
참고 자료
- https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py
- https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/memory_pool.py
- https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [sglang] SGLang의 Linear-Attention 성능 최적화: int8 체크포인트 풀 도입
- 현재글 : [sglang] Mamba GDN의 컨볼루션 캐시 최적화: 메모리 사용량 절반으로 줄이기
- 다음글 [vllm] vLLM Mooncake KV 오프로딩 최적화: 불필요한 KV 조회 건너뛰기
댓글