[SGLang] Hybrid Attention: Dense-Sparse 동적 전환 전략
들어가며
LLM 서빙에서 Prefill과 Decode는 연산 특성이 완전히 다르다. Prefill은 긴 시퀀스를 한 번에 처리하므로 compute-bound이고, Decode는 토큰 하나씩 생성하므로 memory-bound이다. 같은 어텐션 커널을 양쪽에 쓰면 한쪽에서 반드시 비효율이 발생한다.
SGLang의 HybridAttnBackend는 이 문제를 해결한다. Prefill 단계에는 Dense 어텐션 백엔드를, Decode 단계에는 Sparse 어텐션 백엔드를 동적으로 전환하여 각 단계에 최적화된 커널을 사용한다. 이 글에서는 python/sglang/srt/layers/attention/hybrid_attn_backend.py를 분석한다.
구조도
┌─────────────────────────────────────┐
│ HybridAttnBackend │
│ ┌───────────┐ ┌───────────────┐ │
│ │ prefill │ │ decode │ │
│ │ backend │ │ backend │ │
│ └─────┬─────┘ └──────┬────────┘ │
│ │ │ │
│ ┌─────▼────────────────▼────────┐ │
│ │ _select_backend() │ │
│ │ ForwardMode 기반 동적 선택 │ │
│ └───────────────────────────────┘ │
└─────────────────────────────────────┘
│
▼
┌──────────────────────────────┐
│ ForwardMode 분기 │
├──────────────────────────────┤
│ decode_or_idle → decode_backend │
│ target_verify → spec_mode 판단 │
│ draft_extend → spec_mode 판단 │
│ prefill/extend → prefill_backend │
└──────────────────────────────┘
HybridAttnBackend는 두 개의 실제 백엔드를 내부에 보유하고, ForwardMode에 따라 적절한 것을 선택한다.
핵심 코드 분석
생성자: 두 백엔드를 주입받는 구조
class HybridAttnBackend(AttentionBackend):
def __init__(
self,
model_runner: ModelRunner,
prefill_backend: AttentionBackend,
decode_backend: AttentionBackend,
):
self.model_runner = model_runner
self.prefill_backend = prefill_backend
self.decode_backend = decode_backend
self.data_type = model_runner.kv_cache_dtype
생성 시점에 Prefill용과 Decode용 백엔드를 각각 주입받는다. 예를 들어 Prefill에는 FlashAttention, Decode에는 Triton Flash Decoding을 조합할 수 있다. 이 조합은 HybridAttnBackend를 사용하는 상위 코드에서 결정된다.
_select_backend: ForwardMode 기반 동적 전환
def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
if forward_mode.is_decode_or_idle():
return self.decode_backend
elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
return (
self.decode_backend
if self.model_runner.server_args.speculative_attention_mode == "decode"
else self.prefill_backend
)
else:
return self.prefill_backend
전환 로직의 핵심이다. 세 가지 경로로 분기한다.
- Decode/Idle: 항상 decode_backend 사용
- Target Verify / Draft Extend: Speculative Decoding 설정에 따라 결정.
speculative_attention_mode가"decode"면 decode_backend, 아니면 prefill_backend를 사용 - Prefill/Extend: 항상 prefill_backend 사용
Speculative Decoding에서 Target Verify는 여러 토큰을 한 번에 검증하므로, Prefill처럼 처리하는 것이 기본 전략이다. 하지만 토큰 수가 적으면 Decode 커널이 더 효율적일 수 있어 사용자가 선택할 수 있도록 설정을 열어둔다.
forward: 선형 어텐션까지 지원하는 범용 위임
def forward(
self,
q: Optional[torch.Tensor] = None,
k: Optional[torch.Tensor] = None,
v: Optional[torch.Tensor] = None,
layer: Optional[RadixAttention] = None,
forward_batch: Optional[ForwardBatch] = None,
save_kv_cache: bool = True,
*,
mixed_qkv: Optional[torch.Tensor] = None, # For linear attention
a: Optional[torch.Tensor] = None,
b: Optional[torch.Tensor] = None,
**kwargs,
):
backend = self._select_backend(forward_batch.forward_mode)
if mixed_qkv is not None:
return backend.forward(
layer=layer, forward_batch=forward_batch,
save_kv_cache=save_kv_cache,
mixed_qkv=mixed_qkv, a=a, b=b, **kwargs,
)
return backend.forward(q, k, v, layer, forward_batch, save_kv_cache, **kwargs)
forward 메서드는 일반 어텐션(q, k, v)과 선형 어텐션(mixed_qkv, a, b)을 모두 처리한다. mixed_qkv가 전달되면 선형 어텐션 경로로, 아니면 일반 어텐션 경로로 위임한다. 이 설계로 하이브리드 모델(Transformer + Mamba 레이어 혼합)도 동일한 백엔드 인터페이스로 서빙할 수 있다.
CUDA Graph 초기화: 양쪽 백엔드를 동시에 준비
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
if (
self.model_runner.server_args.speculative_algorithm is not None
and self.model_runner.server_args.speculative_attention_mode == "prefill"
):
self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)
CUDA Graph는 커널 호출을 캡처하여 재실행하는 최적화 기법이다. Decode 백엔드는 항상 초기화한다. Prefill 백엔드는 Speculative Decoding이 활성화되고 prefill 모드일 때만 초기화한다. 이는 Target Verify가 CUDA Graph로 실행될 수 있기 때문이다.
forward_decode / forward_extend: 직접 위임
def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
return self.decode_backend.forward_decode(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
backend = self._select_backend(forward_batch.forward_mode)
return backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
forward_decode는 항상 decode_backend에 직접 위임한다. forward_extend는 _select_backend로 선택한 후 위임한다. Target Verify 모드에서 extend가 호출될 수 있으므로, extend에서도 동적 전환이 필요하다.
설계 근거: 왜 런타임 전환인가
| 접근법 | 장점 | 단점 |
|---|---|---|
| 단일 백엔드 | 단순한 구조 | Prefill/Decode 중 하나에서 비효율 |
| 컴파일 타임 분기 | 오버헤드 없음 | 백엔드 조합 변경 시 재빌드 필요 |
| 런타임 전환 | 유연한 조합, 설정으로 변경 | 분기 오버헤드 (무시 가능) |
SGLang은 런타임 전환을 선택했다. _select_backend의 분기 비용은 Python if 문 하나이므로, GPU 커널 실행 시간 대비 완전히 무시할 수 있다. 대신 FlashInfer + Triton, FlashAttention + Custom Decode 등 다양한 조합을 서버 시작 시 설정만으로 바꿀 수 있다.
관련 포스트
- Triton Attention 커널: Python으로 작성하는 GPU 커널
- Mamba (SSM): 선형 시간 복잡도 시퀀스 모델링
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Double Sparsity: H-Sparsity와 T-Sparsity의 이중 최적화
- 현재글 : [SGLang] Hybrid Attention: Dense-Sparse 동적 전환 전략
- 다음글 [SGLang] Triton Attention 커널: Python으로 작성하는 GPU 커널
댓글