본문으로 건너뛰기

[SGLang] RadixAttention Layer: 통합 어텐션 인터페이스의 설계

들어가며

LLM 추론 시스템에서 어텐션 연산은 가장 성능이 중요한 컴포넌트다. FlashAttention, FlashInfer, Triton 등 다양한 백엔드가 존재하며, 각각 하드웨어와 워크로드에 따라 최적의 선택이 달라진다. SGLang의 RadixAttention은 이 모든 백엔드를 하나의 통합 인터페이스 뒤에 숨기는 어텐션 레이어다.

이 글에서는 python/sglang/srt/layers/radix_attention.py를 중심으로 RadixAttention의 설계를 분석한다.

전체 구조

RadixAttention이 어텐션 백엔드와 상호작용하는 전체 흐름은 다음과 같다.

  Model Layer (e.g. DeepseekV2Attention)
       │
       │  q, k, v, forward_batch
       ▼
┌──────────────────────────────────────────┐
│           RadixAttention.forward()        │
│                                          │
│  ┌─ Extend Mode (torch.compile) ──────┐  │
│  │  unified_attention_with_output()    │  │
│  │       │                             │  │
│  │       ▼                             │  │
│  │  get_forward_context()              │  │
│  │       │                             │  │
│  │       ▼                             │  │
│  │  attn_backend.forward()             │  │
│  └─────────────────────────────────────┘  │
│                                          │
│  ┌─ Decode Mode ──────────────────────┐  │
│  │  forward_batch.attn_backend.forward │  │
│  └─────────────────────────────────────┘  │
│                                          │
│          ┌──────────┬──────────┐          │
│          ▼          ▼          ▼          │
│     FlashInfer  FlashAttn   Triton       │
│      Backend     Backend    Backend      │
└──────────────────────────────────────────┘

AttentionType: 어텐션 모드 분류

RadixAttention은 세 가지 어텐션 타입을 지원한다.

class AttentionType(Enum):
    # Decoder attention between previous layer Q/K/V
    DECODER = "decoder"
    # Decoder bidirectional attention between image tokens
    DECODER_BIDIRECTIONAL = "decoder_bidirectional"
    # Encoder attention between previous layer Q/K/V
    ENCODER_ONLY = "encoder_only"

DECODER는 일반적인 causal attention, ENCODER_ONLY는 양방향 어텐션이다. DECODER_BIDIRECTIONAL은 멀티모달 모델에서 이미지 토큰 간 양방향 어텐션에 사용된다. 이 타입은 문자열 기반 Enum으로 설계되어 torch.compile과의 호환성을 보장한다.

RadixAttention 초기화

RadixAttention.__init__은 레이어별 어텐션 파라미터를 설정한다.

class RadixAttention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scaling: float,
        num_kv_heads: int,
        layer_id: int,
        logit_cap: float = 0.0,
        v_head_dim: int = -1,
        sliding_window_size: int = -1,
        is_cross_attention: bool = False,
        pos_encoding_mode: str = "NONE",
        quant_config: Optional[QuantizationConfig] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        use_irope: bool = False,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_q_head_num = num_heads
        self.tp_k_head_num = num_kv_heads
        self.tp_v_head_num = num_kv_heads
        self.head_dim = head_dim
        self.qk_head_dim = head_dim
        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim

주목할 점은 v_head_dim을 별도로 관리한다는 것이다. MLA(Multi-head Latent Attention)처럼 Q/K와 V의 head dimension이 다른 아키텍처를 지원하기 위한 설계다. layer_id는 KV 캐시 인덱싱과 레이어별 양자화에 사용된다.

forward: 백엔드 디스패치의 핵심

forward 메서드는 두 가지 경로로 분기한다.

def forward(
    self, q, k, v,
    forward_batch: ForwardBatch,
    save_kv_cache: bool = True,
    **kwargs,
):
    if k is not None:
        assert v is not None
        if "k_rope" not in kwargs:
            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
        else:
            k = k.view(-1, self.tp_k_head_num, self.v_head_dim)

    if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
        # torch.compile 호환 경로
        output = torch.empty_like(q)
        unified_attention_with_output(
            q, k, v, output, save_kv_cache, self.layer_id, **kwargs
        )
        return output
    else:
        # 일반 경로
        return forward_batch.attn_backend.forward(
            q, k, v, self, forward_batch, save_kv_cache, **kwargs
        )

Extend 모드 + torch.compile 경로: get_forward_context()가 존재하면 unified_attention_with_output을 호출한다. 이 함수는 @register_custom_op으로 등록되어 torch.compile의 graph capture에 포함된다.

Decode 모드 / 일반 경로: forward_batch.attn_backend.forward()를 직접 호출한다. attn_backend는 런타임에 결정된 백엔드 인스턴스다.

unified_attention_with_output: torch.compile 지원

@register_custom_op(mutates_args=["output"])
@register_split_op()
def unified_attention_with_output(
    query, key, value, output, save_kv_cache, layer_id,
    *, q_rope=None, k_rope=None, sinks=None,
) -> None:
    context = get_forward_context()
    forward_batch = context.forward_batch
    attention_layers = context.attention_layers
    attention_layer = attention_layers[layer_id]

    ret = forward_batch.attn_backend.forward(
        query, key, value, attention_layer,
        forward_batch, save_kv_cache, **kwargs
    )
    output.view(ret.shape).copy_(ret)

이 함수의 핵심 설계 결정 두 가지가 있다.

첫째, @register_custom_op(mutates_args=["output"])을 사용한다. torch.compile에서 in-place mutation을 추적하기 위해 output 텐서를 명시적으로 변경 대상으로 등록한다.

둘째, @register_split_op()로 piecewise CUDA graph에서 이 연산을 분리 지점으로 표시한다. 어텐션 연산은 배치 크기에 따라 CUDA graph 캡처와 리플레이 전략이 달라지므로, 독립적으로 처리되어야 한다.

KV 캐시 연동: cross-layer sharing

forward에서 k is not None 조건을 확인하는 이유가 있다.

if k is not None:
    # For cross-layer sharing, kv can be None
    assert v is not None

Cross-layer KV sharing은 여러 어텐션 레이어가 동일한 KV 캐시를 공유하는 기법이다. 이 경우 하위 레이어에서 이미 KV를 캐시에 저장했으므로, 상위 레이어는 k=None, v=None으로 forward를 호출하고 백엔드가 캐시에서 직접 읽는다.

설계 근거: 왜 RadixAttention인가

RadixAttention이라는 이름은 SGLang의 RadixTree 기반 prefix caching에서 유래한다. 이 레이어는 단순한 어텐션 추상화를 넘어 다음을 가능하게 한다.

관심사 RadixAttention의 역할
백엔드 선택 forward_batch.attn_backend로 런타임 디스패치
KV 캐시 token_to_kv_pool과의 연동 자동화
컴파일 unified_attention_with_output으로 torch.compile 호환
양자화 quant_config로 레이어별 KV 양자화 지원
MLA v_head_dim, q_rope, k_rope로 이종 head dim 지원

모델 코드는 RadixAttention의 forward만 호출하면 된다. 어떤 백엔드가 어떻게 어텐션을 계산하는지는 완전히 캡슐화되어 있다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글