[SGLang] GDN (Gated Diagonal Net): 게이트 기반 선형 어텐션
들어가며
Gated Delta Network(GDN)은 Delta Rule 기반의 선형 어텐션 메커니즘이다. 전통적 어텐션의 O(n²) Softmax 연산 대신 게이트를 활용한 재귀적 상태 업데이트로 O(n) 복잡도를 달성한다. SGLang의 GDNAttnBackend는 GDN을 서빙 환경에 최적화하여, Triton, FlashInfer, CuTe DSL 세 가지 커널 백엔드를 모드별로 조합한다.
이 글에서는 python/sglang/srt/layers/attention/linear/gdn_backend.py를 분석한다.
구조도
┌─────────────────────────────────────────────────────────┐
│ GDNAttnBackend │
│ │
│ ┌──────────────────────────────────────────────┐ │
│ │ GDNKernelDispatcher │ │
│ │ ┌─────────┐ ┌──────────┐ ┌──────────────┐ │ │
│ │ │ decode │ │ extend │ │ verify │ │ │
│ │ │ _kernel │ │ _kernel │ │ _kernel │ │ │
│ │ └────┬────┘ └────┬─────┘ └──────┬───────┘ │ │
│ │ │ │ │ │ │
│ │ ┌────▼────┐ ┌────▼────┐ ┌─────▼────┐ │ │
│ │ │Triton / │ │ Triton │ │FlashInfer│ │ │
│ │ │FlashInf/│ │ or │ │ or │ │ │
│ │ │CuteDSL │ │FlashInf │ │ Triton │ │ │
│ │ └─────────┘ └─────────┘ └──────────┘ │ │
│ └──────────────────────────────────────────────┘ │
│ │
│ forward_decode() → causal_conv1d_update + kernel.decode│
│ forward_extend() → causal_conv1d_fn + fused_gdn_gating│
│ + kernel.extend │
└─────────────────────────────────────────────────────────┘
핵심 코드 분석
GDNKernelDispatcher: 모드별 커널 선택
class GDNKernelDispatcher:
def __init__(self, decode_backend, prefill_backend):
triton_kernel = TritonGDNKernel()
if decode_backend.is_triton():
self.decode_kernel = triton_kernel
elif decode_backend.is_cutedsl():
from ...kernels.gdn_cutedsl import CuteDSLGDNKernel
self.decode_kernel = CuteDSLGDNKernel()
elif decode_backend.is_flashinfer():
from ...kernels.gdn_flashinfer import FlashInferGDNKernel
flashinfer_kernel = FlashInferGDNKernel()
self.decode_kernel = flashinfer_kernel
커널 디스패처는 Decode, Extend, Verify 각 모드에 최적의 커널을 배정한다. Decode에는 세 가지 커널(Triton, CuTe DSL, FlashInfer)이 모두 사용 가능하고, Extend에는 Triton과 FlashInfer만 지원된다. CuTe DSL은 Decode 전용이다. 이런 비대칭 설계는 각 커널의 강점이 다르기 때문이다.
Packed Decode: 분리 없는 Fused 커널
def forward_decode(self, layer, forward_batch, mixed_qkv, a, b, **kwargs):
mixed_qkv = causal_conv1d_update(
mixed_qkv, conv_states, layer.conv_weights,
layer.bias, layer.activation, conv_state_indices=cache_indices,
)
# Fused packed decode: split + reshape + gating을 하나로
if self.kernel_dispatcher.supports_packed_decode:
core_attn_out = self.kernel_dispatcher.packed_decode(
mixed_qkv=mixed_qkv, a=a, b=b,
A_log=layer.A_log, dt_bias=layer.dt_bias,
scale=layer.head_k_dim**-0.5,
ssm_states=ssm_states, cache_indices=cache_indices,
num_v_heads=layer.num_v_heads, head_v_dim=layer.head_v_dim,
)
return core_attn_out
GDN의 Decode 경로는 mixed_qkv 텐서를 Q, K, V로 분리하고, reshape한 뒤, 게이팅을 적용해야 한다. packed_decode가 지원되면 이 세 단계를 하나의 Fused 커널로 실행하여 메모리 왕복을 줄인다. 지원되지 않으면 torch.split으로 분리한 후 일반 decode 커널을 호출한다.
비 Packed Decode: 명시적 분리
# Packed decode가 미지원인 경우
query, key, value = torch.split(
mixed_qkv, [layer.q_dim, layer.k_dim, layer.v_dim], dim=-1,
)
bs = forward_batch.batch_size
query = query.view(1, bs, layer.num_q_heads, layer.head_q_dim)
key = key.view(1, bs, layer.num_k_heads, layer.head_k_dim)
value = value.view(1, bs, layer.num_v_heads, layer.head_v_dim)
core_attn_out = self.kernel_dispatcher.decode(
q=query, k=key, v=value, a=a, b=b,
A_log=layer.A_log, dt_bias=layer.dt_bias,
ssm_states=ssm_states, cache_indices=cache_indices,
query_start_loc=query_start_loc,
)
Packed decode가 미지원이면 mixed_qkv를 [q_dim, k_dim, v_dim]으로 split한 후 [1, bs, heads, head_dim] 형태로 reshape한다. 첫 번째 차원의 1은 시퀀스 길이(Decode는 항상 1)를 나타낸다.
Forward Extend: Fused GDN Gating
def forward_extend(self, layer, forward_batch, mixed_qkv, a, b, **kwargs):
mixed_qkv = causal_conv1d_fn(
mixed_qkv.transpose(0, 1),
layer.conv_weights, layer.bias,
activation=layer.activation,
conv_states=conv_states,
has_initial_state=has_initial_states,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
).transpose(0, 1)[:seq_len]
# Fused gating: A_log, a, b, dt_bias → g, beta
g, beta = fused_gdn_gating(layer.A_log, a, b, layer.dt_bias)
core_attn_out, last_recurrent_state, h = self.kernel_dispatcher.extend(
q=query, k=key, v=value, g=g, beta=beta,
ssm_states=ssm_states, cache_indices=cache_indices,
query_start_loc=query_start_loc,
)
Extend 경로에서는 먼저 Causal Convolution으로 시퀀스 전체를 처리하고, fused_gdn_gating으로 게이트 값(g)과 베타를 계산한다. fused_gdn_gating은 A_log, a, b, dt_bias를 입력받아 하나의 Triton 커널로 게이트 연산을 수행한다. 이후 extend 커널이 청크 기반 Gated Delta Rule로 실제 어텐션을 계산한다.
Target Verify: Speculative Decoding 지원
if is_target_verify:
mixed_qkv_reshaped = mixed_qkv.view(
batch_size, draft_token_num, -1
).transpose(1, 2)
mixed_qkv_processed = causal_conv1d_update(
mixed_qkv_reshaped, conv_states, layer.conv_weights,
layer.bias, layer.activation,
conv_state_indices=cache_indices[:batch_size],
intermediate_conv_window=intermediate_conv_window_cache,
intermediate_state_indices=intermediate_state_indices[:batch_size],
retrieve_next_token=retrieve_next_token,
retrieve_next_sibling=retrieve_next_sibling,
retrieve_parent_token=retrieve_parent_token,
)
Speculative Decoding의 Target Verify 모드에서는 드래프트 토큰을 한 번에 검증해야 한다. retrieve_next_token, retrieve_next_sibling, retrieve_parent_token은 트리 구조의 추측 토큰을 올바른 순서로 순회하기 위한 인덱스이다. 중간 상태(intermediate_conv_window_cache, intermediate_state_cache)를 캐싱하여 검증 실패 시 올바른 지점에서 재시작할 수 있다.
Conv State Shape 검증
class GDNAttnBackend(MambaAttnBackendBase):
def __init__(self, model_runner):
super().__init__(model_runner)
self.conv_states_shape = (
model_runner.req_to_token_pool.mamba_pool.mamba_cache.conv[0].shape
)
assert (
self.conv_states_shape[-1] < FLA_CHUNK_SIZE
), f"{self.conv_states_shape[-1]=} should be less than {FLA_CHUNK_SIZE}"
Conv 상태의 마지막 차원(kernel size)이 FLA의 CHUNK_SIZE(64)보다 작아야 한다. 이는 청크 기반 연산에서 Conv 상태가 청크 경계를 넘지 않도록 보장하는 불변식이다.
커널 백엔드 비교
| 커널 | Decode | Extend | Verify | 특징 |
|---|---|---|---|---|
| Triton | O | O | O | 범용, AMD 지원 |
| FlashInfer | O | O | O | CUDA 최적화, packed decode 지원 |
| CuTe DSL | O | X | X | CUDA Decode 전용, 고성능 |
GDNKernelDispatcher는 이 세 커널을 모드별로 독립적으로 조합한다. 예를 들어 Decode에 CuTe DSL, Extend에 FlashInfer를 동시에 사용할 수 있다. 로그에 선택된 조합을 출력하여 디버깅을 돕는다:
rank0_log(
f"GDN kernel dispatcher: decode={self.decode_kernel.__class__.__name__}, "
f"extend={self.extend_kernel.__class__.__name__}, "
f"verify={self.verify_kernel.__class__.__name__} "
f"packed_decode={self.supports_packed_decode}"
)
관련 포스트
- KDA (Kernel-Driven Attention): 커널 기반 선형 어텐션
- FLA (Flashy Linear Attention): 청크 기반 선형 어텐션 연산
- Mamba (SSM): 선형 시간 복잡도 시퀀스 모델링
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Mamba (SSM): 선형 시간 복잡도 시퀀스 모델링
- 현재글 : [SGLang] GDN (Gated Diagonal Net): 게이트 기반 선형 어텐션
- 다음글 [SGLang] KDA (Kernel-Driven Attention): 커널 기반 선형 어텐션
댓글