본문으로 건너뛰기

[SGLang] GDN (Gated Diagonal Net): 게이트 기반 선형 어텐션

들어가며

Gated Delta Network(GDN)은 Delta Rule 기반의 선형 어텐션 메커니즘이다. 전통적 어텐션의 O(n²) Softmax 연산 대신 게이트를 활용한 재귀적 상태 업데이트로 O(n) 복잡도를 달성한다. SGLang의 GDNAttnBackend는 GDN을 서빙 환경에 최적화하여, Triton, FlashInfer, CuTe DSL 세 가지 커널 백엔드를 모드별로 조합한다.

이 글에서는 python/sglang/srt/layers/attention/linear/gdn_backend.py를 분석한다.

구조도

┌─────────────────────────────────────────────────────────┐
│                    GDNAttnBackend                        │
│                                                         │
│  ┌──────────────────────────────────────────────┐       │
│  │           GDNKernelDispatcher                │       │
│  │  ┌─────────┐ ┌──────────┐ ┌──────────────┐  │       │
│  │  │ decode  │ │ extend   │ │   verify     │  │       │
│  │  │ _kernel │ │ _kernel  │ │   _kernel    │  │       │
│  │  └────┬────┘ └────┬─────┘ └──────┬───────┘  │       │
│  │       │           │              │           │       │
│  │  ┌────▼────┐ ┌────▼────┐  ┌─────▼────┐     │       │
│  │  │Triton / │ │ Triton  │  │FlashInfer│     │       │
│  │  │FlashInf/│ │   or    │  │  or      │     │       │
│  │  │CuteDSL │ │FlashInf │  │ Triton   │     │       │
│  │  └─────────┘ └─────────┘  └──────────┘     │       │
│  └──────────────────────────────────────────────┘       │
│                                                         │
│  forward_decode() → causal_conv1d_update + kernel.decode│
│  forward_extend() → causal_conv1d_fn + fused_gdn_gating│
│                     + kernel.extend                     │
└─────────────────────────────────────────────────────────┘

핵심 코드 분석

GDNKernelDispatcher: 모드별 커널 선택

class GDNKernelDispatcher:
    def __init__(self, decode_backend, prefill_backend):
        triton_kernel = TritonGDNKernel()

        if decode_backend.is_triton():
            self.decode_kernel = triton_kernel
        elif decode_backend.is_cutedsl():
            from ...kernels.gdn_cutedsl import CuteDSLGDNKernel
            self.decode_kernel = CuteDSLGDNKernel()
        elif decode_backend.is_flashinfer():
            from ...kernels.gdn_flashinfer import FlashInferGDNKernel
            flashinfer_kernel = FlashInferGDNKernel()
            self.decode_kernel = flashinfer_kernel

커널 디스패처는 Decode, Extend, Verify 각 모드에 최적의 커널을 배정한다. Decode에는 세 가지 커널(Triton, CuTe DSL, FlashInfer)이 모두 사용 가능하고, Extend에는 Triton과 FlashInfer만 지원된다. CuTe DSL은 Decode 전용이다. 이런 비대칭 설계는 각 커널의 강점이 다르기 때문이다.

Packed Decode: 분리 없는 Fused 커널

def forward_decode(self, layer, forward_batch, mixed_qkv, a, b, **kwargs):
    mixed_qkv = causal_conv1d_update(
        mixed_qkv, conv_states, layer.conv_weights,
        layer.bias, layer.activation, conv_state_indices=cache_indices,
    )

    # Fused packed decode: split + reshape + gating을 하나로
    if self.kernel_dispatcher.supports_packed_decode:
        core_attn_out = self.kernel_dispatcher.packed_decode(
            mixed_qkv=mixed_qkv, a=a, b=b,
            A_log=layer.A_log, dt_bias=layer.dt_bias,
            scale=layer.head_k_dim**-0.5,
            ssm_states=ssm_states, cache_indices=cache_indices,
            num_v_heads=layer.num_v_heads, head_v_dim=layer.head_v_dim,
        )
        return core_attn_out

GDN의 Decode 경로는 mixed_qkv 텐서를 Q, K, V로 분리하고, reshape한 뒤, 게이팅을 적용해야 한다. packed_decode가 지원되면 이 세 단계를 하나의 Fused 커널로 실행하여 메모리 왕복을 줄인다. 지원되지 않으면 torch.split으로 분리한 후 일반 decode 커널을 호출한다.

비 Packed Decode: 명시적 분리

    # Packed decode가 미지원인 경우
    query, key, value = torch.split(
        mixed_qkv, [layer.q_dim, layer.k_dim, layer.v_dim], dim=-1,
    )
    bs = forward_batch.batch_size
    query = query.view(1, bs, layer.num_q_heads, layer.head_q_dim)
    key = key.view(1, bs, layer.num_k_heads, layer.head_k_dim)
    value = value.view(1, bs, layer.num_v_heads, layer.head_v_dim)

    core_attn_out = self.kernel_dispatcher.decode(
        q=query, k=key, v=value, a=a, b=b,
        A_log=layer.A_log, dt_bias=layer.dt_bias,
        ssm_states=ssm_states, cache_indices=cache_indices,
        query_start_loc=query_start_loc,
    )

Packed decode가 미지원이면 mixed_qkv[q_dim, k_dim, v_dim]으로 split한 후 [1, bs, heads, head_dim] 형태로 reshape한다. 첫 번째 차원의 1은 시퀀스 길이(Decode는 항상 1)를 나타낸다.

Forward Extend: Fused GDN Gating

def forward_extend(self, layer, forward_batch, mixed_qkv, a, b, **kwargs):
    mixed_qkv = causal_conv1d_fn(
        mixed_qkv.transpose(0, 1),
        layer.conv_weights, layer.bias,
        activation=layer.activation,
        conv_states=conv_states,
        has_initial_state=has_initial_states,
        cache_indices=cache_indices,
        query_start_loc=query_start_loc,
    ).transpose(0, 1)[:seq_len]

    # Fused gating: A_log, a, b, dt_bias → g, beta
    g, beta = fused_gdn_gating(layer.A_log, a, b, layer.dt_bias)

    core_attn_out, last_recurrent_state, h = self.kernel_dispatcher.extend(
        q=query, k=key, v=value, g=g, beta=beta,
        ssm_states=ssm_states, cache_indices=cache_indices,
        query_start_loc=query_start_loc,
    )

Extend 경로에서는 먼저 Causal Convolution으로 시퀀스 전체를 처리하고, fused_gdn_gating으로 게이트 값(g)과 베타를 계산한다. fused_gdn_gatingA_log, a, b, dt_bias를 입력받아 하나의 Triton 커널로 게이트 연산을 수행한다. 이후 extend 커널이 청크 기반 Gated Delta Rule로 실제 어텐션을 계산한다.

Target Verify: Speculative Decoding 지원

if is_target_verify:
    mixed_qkv_reshaped = mixed_qkv.view(
        batch_size, draft_token_num, -1
    ).transpose(1, 2)
    mixed_qkv_processed = causal_conv1d_update(
        mixed_qkv_reshaped, conv_states, layer.conv_weights,
        layer.bias, layer.activation,
        conv_state_indices=cache_indices[:batch_size],
        intermediate_conv_window=intermediate_conv_window_cache,
        intermediate_state_indices=intermediate_state_indices[:batch_size],
        retrieve_next_token=retrieve_next_token,
        retrieve_next_sibling=retrieve_next_sibling,
        retrieve_parent_token=retrieve_parent_token,
    )

Speculative Decoding의 Target Verify 모드에서는 드래프트 토큰을 한 번에 검증해야 한다. retrieve_next_token, retrieve_next_sibling, retrieve_parent_token은 트리 구조의 추측 토큰을 올바른 순서로 순회하기 위한 인덱스이다. 중간 상태(intermediate_conv_window_cache, intermediate_state_cache)를 캐싱하여 검증 실패 시 올바른 지점에서 재시작할 수 있다.

Conv State Shape 검증

class GDNAttnBackend(MambaAttnBackendBase):
    def __init__(self, model_runner):
        super().__init__(model_runner)
        self.conv_states_shape = (
            model_runner.req_to_token_pool.mamba_pool.mamba_cache.conv[0].shape
        )
        assert (
            self.conv_states_shape[-1] < FLA_CHUNK_SIZE
        ), f"{self.conv_states_shape[-1]=} should be less than {FLA_CHUNK_SIZE}"

Conv 상태의 마지막 차원(kernel size)이 FLA의 CHUNK_SIZE(64)보다 작아야 한다. 이는 청크 기반 연산에서 Conv 상태가 청크 경계를 넘지 않도록 보장하는 불변식이다.

커널 백엔드 비교

커널 Decode Extend Verify 특징
Triton O O O 범용, AMD 지원
FlashInfer O O O CUDA 최적화, packed decode 지원
CuTe DSL O X X CUDA Decode 전용, 고성능

GDNKernelDispatcher는 이 세 커널을 모드별로 독립적으로 조합한다. 예를 들어 Decode에 CuTe DSL, Extend에 FlashInfer를 동시에 사용할 수 있다. 로그에 선택된 조합을 출력하여 디버깅을 돕는다:

rank0_log(
    f"GDN kernel dispatcher: decode={self.decode_kernel.__class__.__name__}, "
    f"extend={self.extend_kernel.__class__.__name__}, "
    f"verify={self.verify_kernel.__class__.__name__} "
    f"packed_decode={self.supports_packed_decode}"
)

관련 포스트

  • KDA (Kernel-Driven Attention): 커널 기반 선형 어텐션
  • FLA (Flashy Linear Attention): 청크 기반 선형 어텐션 연산
  • Mamba (SSM): 선형 시간 복잡도 시퀀스 모델링

참고

댓글

관련 포스트

SGLang 의 다른글