본문으로 건너뛰기

[SGLang] Hybrid Attention: Dense-Sparse 동적 전환 전략

들어가며

LLM 서빙에서 Prefill과 Decode는 연산 특성이 완전히 다르다. Prefill은 긴 시퀀스를 한 번에 처리하므로 compute-bound이고, Decode는 토큰 하나씩 생성하므로 memory-bound이다. 같은 어텐션 커널을 양쪽에 쓰면 한쪽에서 반드시 비효율이 발생한다.

SGLang의 HybridAttnBackend는 이 문제를 해결한다. Prefill 단계에는 Dense 어텐션 백엔드를, Decode 단계에는 Sparse 어텐션 백엔드를 동적으로 전환하여 각 단계에 최적화된 커널을 사용한다. 이 글에서는 python/sglang/srt/layers/attention/hybrid_attn_backend.py를 분석한다.

구조도

┌─────────────────────────────────────┐
│         HybridAttnBackend           │
│  ┌───────────┐   ┌───────────────┐  │
│  │  prefill  │   │    decode     │  │
│  │  backend  │   │    backend    │  │
│  └─────┬─────┘   └──────┬────────┘  │
│        │                │           │
│  ┌─────▼────────────────▼────────┐  │
│  │      _select_backend()        │  │
│  │  ForwardMode 기반 동적 선택   │  │
│  └───────────────────────────────┘  │
└─────────────────────────────────────┘
         │
         ▼
  ┌──────────────────────────────┐
  │       ForwardMode 분기       │
  ├──────────────────────────────┤
  │ decode_or_idle → decode_backend  │
  │ target_verify  → spec_mode 판단  │
  │ draft_extend   → spec_mode 판단  │
  │ prefill/extend → prefill_backend │
  └──────────────────────────────┘

HybridAttnBackend는 두 개의 실제 백엔드를 내부에 보유하고, ForwardMode에 따라 적절한 것을 선택한다.

핵심 코드 분석

생성자: 두 백엔드를 주입받는 구조

class HybridAttnBackend(AttentionBackend):
    def __init__(
        self,
        model_runner: ModelRunner,
        prefill_backend: AttentionBackend,
        decode_backend: AttentionBackend,
    ):
        self.model_runner = model_runner
        self.prefill_backend = prefill_backend
        self.decode_backend = decode_backend
        self.data_type = model_runner.kv_cache_dtype

생성 시점에 Prefill용과 Decode용 백엔드를 각각 주입받는다. 예를 들어 Prefill에는 FlashAttention, Decode에는 Triton Flash Decoding을 조합할 수 있다. 이 조합은 HybridAttnBackend를 사용하는 상위 코드에서 결정된다.

_select_backend: ForwardMode 기반 동적 전환

def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
    if forward_mode.is_decode_or_idle():
        return self.decode_backend
    elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
        return (
            self.decode_backend
            if self.model_runner.server_args.speculative_attention_mode == "decode"
            else self.prefill_backend
        )
    else:
        return self.prefill_backend

전환 로직의 핵심이다. 세 가지 경로로 분기한다.

  1. Decode/Idle: 항상 decode_backend 사용
  2. Target Verify / Draft Extend: Speculative Decoding 설정에 따라 결정. speculative_attention_mode"decode"면 decode_backend, 아니면 prefill_backend를 사용
  3. Prefill/Extend: 항상 prefill_backend 사용

Speculative Decoding에서 Target Verify는 여러 토큰을 한 번에 검증하므로, Prefill처럼 처리하는 것이 기본 전략이다. 하지만 토큰 수가 적으면 Decode 커널이 더 효율적일 수 있어 사용자가 선택할 수 있도록 설정을 열어둔다.

forward: 선형 어텐션까지 지원하는 범용 위임

def forward(
    self,
    q: Optional[torch.Tensor] = None,
    k: Optional[torch.Tensor] = None,
    v: Optional[torch.Tensor] = None,
    layer: Optional[RadixAttention] = None,
    forward_batch: Optional[ForwardBatch] = None,
    save_kv_cache: bool = True,
    *,
    mixed_qkv: Optional[torch.Tensor] = None,  # For linear attention
    a: Optional[torch.Tensor] = None,
    b: Optional[torch.Tensor] = None,
    **kwargs,
):
    backend = self._select_backend(forward_batch.forward_mode)
    if mixed_qkv is not None:
        return backend.forward(
            layer=layer, forward_batch=forward_batch,
            save_kv_cache=save_kv_cache,
            mixed_qkv=mixed_qkv, a=a, b=b, **kwargs,
        )
    return backend.forward(q, k, v, layer, forward_batch, save_kv_cache, **kwargs)

forward 메서드는 일반 어텐션(q, k, v)과 선형 어텐션(mixed_qkv, a, b)을 모두 처리한다. mixed_qkv가 전달되면 선형 어텐션 경로로, 아니면 일반 어텐션 경로로 위임한다. 이 설계로 하이브리드 모델(Transformer + Mamba 레이어 혼합)도 동일한 백엔드 인터페이스로 서빙할 수 있다.

CUDA Graph 초기화: 양쪽 백엔드를 동시에 준비

def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
    self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
    if (
        self.model_runner.server_args.speculative_algorithm is not None
        and self.model_runner.server_args.speculative_attention_mode == "prefill"
    ):
        self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)

CUDA Graph는 커널 호출을 캡처하여 재실행하는 최적화 기법이다. Decode 백엔드는 항상 초기화한다. Prefill 백엔드는 Speculative Decoding이 활성화되고 prefill 모드일 때만 초기화한다. 이는 Target Verify가 CUDA Graph로 실행될 수 있기 때문이다.

forward_decode / forward_extend: 직접 위임

def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
    return self.decode_backend.forward_decode(
        q, k, v, layer, forward_batch, save_kv_cache, **kwargs
    )

def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
    backend = self._select_backend(forward_batch.forward_mode)
    return backend.forward_extend(
        q, k, v, layer, forward_batch, save_kv_cache, **kwargs
    )

forward_decode는 항상 decode_backend에 직접 위임한다. forward_extend_select_backend로 선택한 후 위임한다. Target Verify 모드에서 extend가 호출될 수 있으므로, extend에서도 동적 전환이 필요하다.

설계 근거: 왜 런타임 전환인가

접근법 장점 단점
단일 백엔드 단순한 구조 Prefill/Decode 중 하나에서 비효율
컴파일 타임 분기 오버헤드 없음 백엔드 조합 변경 시 재빌드 필요
런타임 전환 유연한 조합, 설정으로 변경 분기 오버헤드 (무시 가능)

SGLang은 런타임 전환을 선택했다. _select_backend의 분기 비용은 Python if 문 하나이므로, GPU 커널 실행 시간 대비 완전히 무시할 수 있다. 대신 FlashInfer + Triton, FlashAttention + Custom Decode 등 다양한 조합을 서버 시작 시 설정만으로 바꿀 수 있다.

관련 포스트

  • Triton Attention 커널: Python으로 작성하는 GPU 커널
  • Mamba (SSM): 선형 시간 복잡도 시퀀스 모델링

참고

댓글

관련 포스트

SGLang 의 다른글