[SGLang] Attention Registry: 동적 백엔드 선택 메커니즘
들어가며
SGLang은 15개 이상의 어텐션 백엔드를 지원한다. FlashInfer, FlashAttention v3/v4, Triton, FlashMLA, CUTLASS MLA, NSA, Intel AMX, AMD AIter 등 하드웨어와 모델 아키텍처에 따라 최적의 백엔드가 달라진다. 이 다양한 백엔드를 런타임에 동적으로 선택하고 인스턴스화하는 것이 Attention Registry의 역할이다.
이 글에서는 python/sglang/srt/layers/attention/attention_registry.py를 분석한다.
전체 구조
Attention Registry의 백엔드 선택 흐름은 다음과 같다.
ServerArgs.attention_backend = "flashinfer"
│
▼
┌──────────────────────────────────────────┐
│ ATTENTION_BACKENDS (dict) │
│ │
│ "flashinfer" → create_flashinfer_backend│
│ "triton" → create_triton_backend │
│ "fa3" → create_fa3_backend │
│ "fa4" → create_fa4_backend │
│ "flashmla" → create_flashmla_backend │
│ "cutlass_mla"→ create_cutlass_mla_... │
│ "nsa" → create_nsa_backend │
│ "aiter" → create_aiter_backend │
│ "torch_native"→ create_torch_native_... │
│ "intel_amx" → create_intel_amx_... │
│ ... │
└──────────┬───────────────────────────────┘
│
▼
create_flashinfer_backend(runner)
│
├─ MLA 모델? ──▶ FlashInferMLAAttnBackend
│
└─ 일반 모델? ──▶ FlashInferAttnBackend
│
▼
attn_backend_wrapper(runner, backend)
│
├─ Hybrid GDN? ──▶ HybridLinearAttnBackend
├─ Mamba2? ──▶ Mamba2AttnBackend
└─ 일반 모델? ──▶ backend (그대로 반환)
레지스트리 패턴: 데코레이터 기반 등록
Registry의 핵심은 ATTENTION_BACKENDS 딕셔너리와 register_attention_backend 데코레이터다.
ATTENTION_BACKENDS = {}
def register_attention_backend(name):
def decorator(fn):
ATTENTION_BACKENDS[name] = fn
return fn
return decorator
각 백엔드는 데코레이터로 자신을 등록한다. 이 패턴의 장점은 새로운 백엔드 추가 시 레지스트리 코드를 수정할 필요 없이 팩토리 함수만 작성하면 된다는 것이다.
백엔드별 팩토리 함수
FlashInfer: MLA 자동 감지
@register_attention_backend("flashinfer")
def create_flashinfer_backend(runner):
if not runner.use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
return FlashInferAttnBackend(
runner, init_new_workspace=runner.init_new_workspace
)
else:
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend,
)
return FlashInferMLAAttnBackend(runner)
FlashInfer 백엔드는 runner.use_mla_backend 플래그로 MLA 모델 여부를 감지한다. DeepSeek-V2 같은 MLA 모델이면 FlashInferMLAAttnBackend를, 일반 모델이면 FlashInferAttnBackend를 생성한다. Speculative decoding(EAGLE) 시에는 별도의 CUDA stream을 초기화한다.
FlashAttention v3/v4: SM 아키텍처 검증
@register_attention_backend("fa3")
def create_flashattention_v3_backend(runner):
assert (
torch.cuda.get_device_capability()[0] == 8 and not runner.use_mla_backend
) or torch.cuda.get_device_capability()[0] == 9, (
"FlashAttention v3 Backend requires SM>=80 and SM<=90."
)
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
return FlashAttentionBackend(runner)
@register_attention_backend("fa4")
def create_flashattention_v4_backend(runner):
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
return FlashAttentionBackend(runner, fa_impl_ver=4)
FA3은 SM80~SM90(A100, H100)을 요구하며, MLA가 아닌 모델에서만 SM80을 지원한다. FA4는 fa_impl_ver=4로 동일한 FlashAttentionBackend 클래스를 재사용한다.
Triton: Double Sparsity 분기
@register_attention_backend("triton")
def create_triton_backend(runner):
assert not runner.model_config.is_encoder_decoder
if runner.server_args.enable_double_sparsity:
from sglang.srt.layers.attention.double_sparsity_backend import (
DoubleSparseAttnBackend,
)
return DoubleSparseAttnBackend(runner)
else:
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
return TritonAttnBackend(runner)
Triton 백엔드는 enable_double_sparsity 옵션에 따라 일반 Triton 또는 Double Sparsity 백엔드로 분기한다. Cross attention(encoder-decoder 모델)은 지원하지 않는다.
하드웨어 전용 백엔드
@register_attention_backend("aiter") # AMD GPU
def create_aiter_backend(runner): ...
@register_attention_backend("ascend") # Huawei NPU
def create_ascend_backend(runner): ...
@register_attention_backend("intel_amx") # Intel CPU
def create_intel_amx_backend(runner): ...
@register_attention_backend("intel_xpu") # Intel GPU
def create_intel_xpu_backend(runner): ...
attn_backend_wrapper: 하이브리드 모델 처리
일반 어텐션 백엔드를 생성한 뒤, attn_backend_wrapper가 하이브리드 모델을 위한 래핑을 수행한다.
def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"):
if cfg := runner.mambaish_config:
check_environments()
initialize_linear_attn_config(runner.server_args)
if runner.hybrid_gdn_config is not None:
linear_attn_backend = GDNAttnBackend(runner)
elif runner.mamba2_config is not None:
linear_attn_backend = Mamba2AttnBackend(runner)
elif runner.kimi_linear_config is not None:
linear_attn_backend = KDAAttnBackend(runner)
elif runner.hybrid_lightning_config is not None:
linear_attn_backend = LightningAttentionBackend(runner)
# ...
full_attn_layers = cfg.full_attention_layer_ids
return HybridLinearAttnBackend(
full_attn_backend, linear_attn_backend, full_attn_layers
)
return full_attn_backend
Jamba, NemotronH 같은 하이브리드 모델은 일부 레이어에서 Mamba2나 Linear Attention을 사용하고, 나머지 레이어에서 일반 어텐션을 사용한다. HybridLinearAttnBackend는 full_attention_layer_ids를 기준으로 레이어별로 적절한 백엔드를 디스패치한다.
등록된 백엔드 전체 목록
| 이름 | 대상 하드웨어 | 모델 제약 |
|---|---|---|
flashinfer |
NVIDIA GPU | MLA 자동 감지 |
fa3 |
SM80-SM90 | MLA 미지원(SM80) |
fa4 |
SM90+ | 범용 |
triton |
NVIDIA GPU | encoder-decoder 미지원 |
flashmla |
NVIDIA GPU | MLA 전용 |
cutlass_mla |
NVIDIA GPU | MLA 전용 |
nsa |
NVIDIA GPU | DeepSeek NSA 전용 |
trtllm_mla |
NVIDIA GPU | MLA 전용 |
trtllm_mha |
NVIDIA GPU | MHA 전용 |
aiter |
AMD GPU | - |
ascend |
Huawei NPU | - |
intel_amx |
Intel CPU | - |
intel_xpu |
Intel GPU | - |
torch_native |
범용 | 참조 구현 |
flex_attention |
범용 | torch.compile 기반 |
wave |
- | 실험적 |
dual_chunk_flash_attn |
NVIDIA GPU | Dual Chunk 모델 |
설계 근거: Lazy Import
모든 팩토리 함수는 백엔드 클래스를 함수 내부에서 import한다. 이는 의도적인 설계다.
@register_attention_backend("flashinfer")
def create_flashinfer_backend(runner):
import torch # 함수 내부 import
if not runner.use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
이유는 CUDA context 초기화 시점을 제어하기 위해서다. FlashInfer, CUTLASS 등의 모듈은 import 시 CUDA context를 초기화한다. 서버 시작 시 모든 백엔드를 import하면 불필요한 GPU 메모리 할당과 초기화 오버헤드가 발생한다. Lazy import로 실제 사용되는 백엔드만 초기화한다.
관련 포스트
참고
관련 포스트
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
- [sglang] sglang diffusion 모델 성능 향상: Cache-DiT와 torch.compile의 최적화된 적용 순서
- [sglang] NixlKVManager 성능 향상: 비동기 및 멀티스레드 KV 전송 도입
SGLang 의 다른글
- 이전글 [SGLang] RadixAttention Layer: 통합 어텐션 인터페이스의 설계
- 현재글 : [SGLang] Attention Registry: 동적 백엔드 선택 메커니즘
- 다음글 [SGLang] FlashAttention 백엔드: IO-aware 타일링 어텐션의 구현
댓글