[SGLang] KDA (Kernel-Driven Attention): 커널 기반 선형 어텐션
들어가며
KDA(Kimi Delta Attention)는 Moonshot AI의 Kimi 모델에서 사용하는 선형 어텐션 메커니즘이다. GDN과 유사하게 Delta Rule 기반이지만, 게이트가 head-wise 차원에 적용되는 구조적 차이가 있다. SGLang의 KDAAttnBackend는 Triton과 CuTe DSL 두 가지 커널을 지원하며, GDN 대비 더 간결한 구조를 가진다.
이 글에서는 python/sglang/srt/layers/attention/linear/kda_backend.py를 분석한다.
구조도
┌─────────────────────────────────────────────────┐
│ KDAAttnBackend │
│ │
│ ┌───────────────────────────────────────┐ │
│ │ KDAKernelDispatcher │ │
│ │ ┌──────────┐ ┌──────────────┐ │ │
│ │ │ decode │ │ extend │ │ │
│ │ │ _kernel │ │ _kernel │ │ │
│ │ └────┬─────┘ └──────┬───────┘ │ │
│ │ │ │ │ │
│ │ ┌────▼─────┐ ┌────▼─────┐ │ │
│ │ │ Triton / │ │ Triton │ │ │
│ │ │ CuteDSL │ │ (only) │ │ │
│ │ └──────────┘ └──────────┘ │ │
│ └───────────────────────────────────────┘ │
│ │
│ forward_decode() │
│ → causal_conv1d_update (Triton fallback) │
│ → split q,k,v → unflatten → decode kernel │
│ │
│ forward_extend() │
│ → split conv_weights per q,k,v │
│ → causal_conv1d_fn × 3 │
│ → extend kernel │
└─────────────────────────────────────────────────┘
핵심 코드 분석
KDAKernelDispatcher: GDN 대비 제한적 커널 조합
class KDAKernelDispatcher:
def __init__(self, decode_backend, prefill_backend):
triton_kernel = TritonKDAKernel()
if decode_backend.is_triton():
self.decode_kernel = triton_kernel
elif decode_backend.is_cutedsl():
if not is_cuda():
raise ValueError("KDA CuTe DSL backend requires CUDA")
from ...kernels.kda_cutedsl import CuteDSLKDAKernel
self.decode_kernel = CuteDSLKDAKernel()
else:
raise ValueError(
f"Unsupported KDA decode backend: {decode_backend}. "
"KDA currently only supports 'triton'."
)
if prefill_backend.is_triton():
self.extend_kernel = triton_kernel
else:
raise ValueError(
f"Unsupported KDA prefill backend: {prefill_backend}. "
"KDA currently only supports 'triton'."
)
GDN과 비교하면 KDA의 커널 지원은 제한적이다. Decode에 Triton과 CuTe DSL, Extend에는 Triton만 사용할 수 있다. FlashInfer 백엔드는 아직 KDA를 지원하지 않는다. 이는 KDA가 GDN 대비 후발 주자이기 때문이다.
KDA vs GDN: Causal Convolution의 차이
GDN은 mixed_qkv를 하나로 묶어서 Convolution하지만, KDA는 Q, K, V를 개별적으로 Convolution한다.
# KDA: Conv 가중치를 Q, K, V별로 분리
q_conv_weight, k_conv_weight, v_conv_weight = layer.conv_weights.split(
splits, dim=0
)
q_conv_state, k_conv_state, v_conv_state = conv_states.split(splits, dim=-2)
# Q, K, V 각각 별도 causal_conv1d_fn 호출
q = causal_conv1d_fn(
q, q_conv_weight, q_bias, activation="silu",
conv_states=q_conv_state,
has_initial_state=has_initial_state,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
k = causal_conv1d_fn(
k, k_conv_weight, k_bias, activation="silu",
conv_states=k_conv_state, ...
).transpose(0, 1)
v = causal_conv1d_fn(
v, v_conv_weight, v_bias, activation="silu",
conv_states=v_conv_state, ...
).transpose(0, 1)
GDN은 Q+K+V를 결합하여 한 번의 causal_conv1d_fn으로 처리하지만, KDA는 세 번 호출한다. 이는 KDA 모델의 Conv 가중치가 Q, K, V별로 독립적으로 학습되기 때문이다. 커널 호출 횟수는 3배이지만, 각 호출의 차원이 1/3이므로 총 연산량은 유사하다.
Forward Decode: Conv State Transpose
def forward_decode(self, layer, mixed_qkv, a, b, **kwargs):
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer.layer_id)
conv_states = layer_cache.conv[0]
ssm_states = layer_cache.temporal
qkv = causal_conv1d_update(
mixed_qkv,
conv_states.transpose(-1, -2), # KDA 특유의 transpose
layer.conv_weights, layer.bias,
activation="silu",
conv_state_indices=cache_indices,
)
q, k, v = qkv.split([layer.q_dim, layer.k_dim, layer.v_dim], dim=-1)
q = q.unflatten(-1, (-1, layer.head_q_dim)).unsqueeze(0)
k = k.unflatten(-1, (-1, layer.head_k_dim)).unsqueeze(0)
v = v.unflatten(-1, (-1, layer.head_v_dim)).unsqueeze(0)
KDA는 conv_states를 transpose(-1, -2)하여 사용한다. GDN은 Conv State를 [pool_size, conv_dim, conv_kernel] 형태로 저장하지만, KDA는 [pool_size, conv_kernel, conv_dim] 형태를 기대한다. 이 차이는 KDA 모델의 Causal Conv 구현이 다른 메모리 레이아웃을 사용하기 때문이다.
Forward Extend: Triton Conv만 사용
# KDA always uses the triton causal_conv1d_fn (no CUDA override).
# Only causal_conv1d_update needs platform-specific overrides for decode.
소스 코드 주석에서 명시한 것처럼, KDA의 Extend는 항상 Triton 구현의 causal_conv1d_fn을 사용한다. CUDA 네이티브 구현으로의 오버라이드는 Decode의 causal_conv1d_update에만 적용된다. 이는 KDA의 개별 Q/K/V Convolution이 CUDA 네이티브 커널의 최적화 경로와 맞지 않기 때문이다.
Extend 커널: g와 beta 직접 전달
def forward_extend(self, layer, forward_batch, mixed_qkv, a, b, **kwargs):
# ... Q, K, V 개별 conv 처리 후 ...
core_attn_out = self.kernel_dispatcher.extend(
q=q, k=k, v=v,
g=a, # KDA: a를 gate로 직접 사용
beta=b, # KDA: b를 beta로 직접 사용
ssm_states=ssm_states,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
)
GDN은 fused_gdn_gating(A_log, a, b, dt_bias)로 gate와 beta를 계산하지만, KDA는 a와 b를 직접 g와 beta로 전달한다. KDA 모델에서는 a가 이미 gate의 역할을 하도록 학습되었기 때문이다. 이 차이가 KDA의 구조적 단순함을 만든다.
GDN vs KDA 비교
| 특성 | GDN | KDA |
|---|---|---|
| 게이팅 | fused_gdn_gating 별도 계산 |
a, b 직접 사용 |
| Convolution | Q+K+V 합쳐서 1회 | Q, K, V 분리하여 3회 |
| Conv State 레이아웃 | [pool, dim, kernel] |
[pool, kernel, dim] (transpose) |
| Decode 커널 | Triton / FlashInfer / CuTe DSL | Triton / CuTe DSL |
| Extend 커널 | Triton / FlashInfer | Triton만 |
| Packed Decode | 지원 (FlashInfer) | 미지원 |
| Target Verify | 지원 (별도 verify 커널) | 미구현 |
설계 근거
KDA가 GDN보다 간결한 이유는 모델 아키텍처 자체의 차이에서 온다. GDN은 A_log, dt_bias를 포함하는 SSM 스타일의 파라미터화를 사용하여 별도 게이팅 연산이 필요하다. KDA는 Convolution과 게이팅을 분리하여, 프로젝션 레이어에서 이미 적절한 a, b 값을 만들어낸다. 서빙 관점에서 KDA는 구현이 단순하지만 커널 최적화 여지는 GDN보다 적다.
관련 포스트
- GDN (Gated Diagonal Net): 게이트 기반 선형 어텐션
- FLA (Flashy Linear Attention): 청크 기반 선형 어텐션 연산
- Lightning Attention: 고속 선형 어텐션 구현
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] GDN (Gated Diagonal Net): 게이트 기반 선형 어텐션
- 현재글 : [SGLang] KDA (Kernel-Driven Attention): 커널 기반 선형 어텐션
- 다음글 [SGLang] Lightning Attention: 고속 선형 어텐션 구현
댓글