본문으로 건너뛰기

[SGLang] Attention Registry: 동적 백엔드 선택 메커니즘

들어가며

SGLang은 15개 이상의 어텐션 백엔드를 지원한다. FlashInfer, FlashAttention v3/v4, Triton, FlashMLA, CUTLASS MLA, NSA, Intel AMX, AMD AIter 등 하드웨어와 모델 아키텍처에 따라 최적의 백엔드가 달라진다. 이 다양한 백엔드를 런타임에 동적으로 선택하고 인스턴스화하는 것이 Attention Registry의 역할이다.

이 글에서는 python/sglang/srt/layers/attention/attention_registry.py를 분석한다.

전체 구조

Attention Registry의 백엔드 선택 흐름은 다음과 같다.

  ServerArgs.attention_backend = "flashinfer"
       │
       ▼
┌──────────────────────────────────────────┐
│         ATTENTION_BACKENDS (dict)         │
│                                          │
│  "flashinfer" → create_flashinfer_backend│
│  "triton"     → create_triton_backend    │
│  "fa3"        → create_fa3_backend       │
│  "fa4"        → create_fa4_backend       │
│  "flashmla"   → create_flashmla_backend  │
│  "cutlass_mla"→ create_cutlass_mla_...   │
│  "nsa"        → create_nsa_backend       │
│  "aiter"      → create_aiter_backend     │
│  "torch_native"→ create_torch_native_... │
│  "intel_amx"  → create_intel_amx_...     │
│  ...                                     │
└──────────┬───────────────────────────────┘
           │
           ▼
  create_flashinfer_backend(runner)
           │
           ├─ MLA 모델? ──▶ FlashInferMLAAttnBackend
           │
           └─ 일반 모델? ──▶ FlashInferAttnBackend
           │
           ▼
  attn_backend_wrapper(runner, backend)
           │
           ├─ Hybrid GDN?  ──▶ HybridLinearAttnBackend
           ├─ Mamba2?      ──▶ Mamba2AttnBackend
           └─ 일반 모델?    ──▶ backend (그대로 반환)

레지스트리 패턴: 데코레이터 기반 등록

Registry의 핵심은 ATTENTION_BACKENDS 딕셔너리와 register_attention_backend 데코레이터다.

ATTENTION_BACKENDS = {}

def register_attention_backend(name):
    def decorator(fn):
        ATTENTION_BACKENDS[name] = fn
        return fn
    return decorator

각 백엔드는 데코레이터로 자신을 등록한다. 이 패턴의 장점은 새로운 백엔드 추가 시 레지스트리 코드를 수정할 필요 없이 팩토리 함수만 작성하면 된다는 것이다.

백엔드별 팩토리 함수

FlashInfer: MLA 자동 감지

@register_attention_backend("flashinfer")
def create_flashinfer_backend(runner):
    if not runner.use_mla_backend:
        from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
        return FlashInferAttnBackend(
            runner, init_new_workspace=runner.init_new_workspace
        )
    else:
        from sglang.srt.layers.attention.flashinfer_mla_backend import (
            FlashInferMLAAttnBackend,
        )
        return FlashInferMLAAttnBackend(runner)

FlashInfer 백엔드는 runner.use_mla_backend 플래그로 MLA 모델 여부를 감지한다. DeepSeek-V2 같은 MLA 모델이면 FlashInferMLAAttnBackend를, 일반 모델이면 FlashInferAttnBackend를 생성한다. Speculative decoding(EAGLE) 시에는 별도의 CUDA stream을 초기화한다.

FlashAttention v3/v4: SM 아키텍처 검증

@register_attention_backend("fa3")
def create_flashattention_v3_backend(runner):
    assert (
        torch.cuda.get_device_capability()[0] == 8 and not runner.use_mla_backend
    ) or torch.cuda.get_device_capability()[0] == 9, (
        "FlashAttention v3 Backend requires SM>=80 and SM<=90."
    )
    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
    return FlashAttentionBackend(runner)

@register_attention_backend("fa4")
def create_flashattention_v4_backend(runner):
    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
    return FlashAttentionBackend(runner, fa_impl_ver=4)

FA3은 SM80~SM90(A100, H100)을 요구하며, MLA가 아닌 모델에서만 SM80을 지원한다. FA4는 fa_impl_ver=4로 동일한 FlashAttentionBackend 클래스를 재사용한다.

Triton: Double Sparsity 분기

@register_attention_backend("triton")
def create_triton_backend(runner):
    assert not runner.model_config.is_encoder_decoder
    if runner.server_args.enable_double_sparsity:
        from sglang.srt.layers.attention.double_sparsity_backend import (
            DoubleSparseAttnBackend,
        )
        return DoubleSparseAttnBackend(runner)
    else:
        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
        return TritonAttnBackend(runner)

Triton 백엔드는 enable_double_sparsity 옵션에 따라 일반 Triton 또는 Double Sparsity 백엔드로 분기한다. Cross attention(encoder-decoder 모델)은 지원하지 않는다.

하드웨어 전용 백엔드

@register_attention_backend("aiter")    # AMD GPU
def create_aiter_backend(runner): ...

@register_attention_backend("ascend")   # Huawei NPU
def create_ascend_backend(runner): ...

@register_attention_backend("intel_amx") # Intel CPU
def create_intel_amx_backend(runner): ...

@register_attention_backend("intel_xpu") # Intel GPU
def create_intel_xpu_backend(runner): ...

attn_backend_wrapper: 하이브리드 모델 처리

일반 어텐션 백엔드를 생성한 뒤, attn_backend_wrapper가 하이브리드 모델을 위한 래핑을 수행한다.

def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"):
    if cfg := runner.mambaish_config:
        check_environments()
        initialize_linear_attn_config(runner.server_args)
        if runner.hybrid_gdn_config is not None:
            linear_attn_backend = GDNAttnBackend(runner)
        elif runner.mamba2_config is not None:
            linear_attn_backend = Mamba2AttnBackend(runner)
        elif runner.kimi_linear_config is not None:
            linear_attn_backend = KDAAttnBackend(runner)
        elif runner.hybrid_lightning_config is not None:
            linear_attn_backend = LightningAttentionBackend(runner)
        # ...
        full_attn_layers = cfg.full_attention_layer_ids
        return HybridLinearAttnBackend(
            full_attn_backend, linear_attn_backend, full_attn_layers
        )
    return full_attn_backend

Jamba, NemotronH 같은 하이브리드 모델은 일부 레이어에서 Mamba2나 Linear Attention을 사용하고, 나머지 레이어에서 일반 어텐션을 사용한다. HybridLinearAttnBackendfull_attention_layer_ids를 기준으로 레이어별로 적절한 백엔드를 디스패치한다.

등록된 백엔드 전체 목록

이름 대상 하드웨어 모델 제약
flashinfer NVIDIA GPU MLA 자동 감지
fa3 SM80-SM90 MLA 미지원(SM80)
fa4 SM90+ 범용
triton NVIDIA GPU encoder-decoder 미지원
flashmla NVIDIA GPU MLA 전용
cutlass_mla NVIDIA GPU MLA 전용
nsa NVIDIA GPU DeepSeek NSA 전용
trtllm_mla NVIDIA GPU MLA 전용
trtllm_mha NVIDIA GPU MHA 전용
aiter AMD GPU -
ascend Huawei NPU -
intel_amx Intel CPU -
intel_xpu Intel GPU -
torch_native 범용 참조 구현
flex_attention 범용 torch.compile 기반
wave - 실험적
dual_chunk_flash_attn NVIDIA GPU Dual Chunk 모델

설계 근거: Lazy Import

모든 팩토리 함수는 백엔드 클래스를 함수 내부에서 import한다. 이는 의도적인 설계다.

@register_attention_backend("flashinfer")
def create_flashinfer_backend(runner):
    import torch  # 함수 내부 import
    if not runner.use_mla_backend:
        from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend

이유는 CUDA context 초기화 시점을 제어하기 위해서다. FlashInfer, CUTLASS 등의 모듈은 import 시 CUDA context를 초기화한다. 서버 시작 시 모든 백엔드를 import하면 불필요한 GPU 메모리 할당과 초기화 오버헤드가 발생한다. Lazy import로 실제 사용되는 백엔드만 초기화한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글