본문으로 건너뛰기

[SGLang] KDA (Kernel-Driven Attention): 커널 기반 선형 어텐션

들어가며

KDA(Kimi Delta Attention)는 Moonshot AI의 Kimi 모델에서 사용하는 선형 어텐션 메커니즘이다. GDN과 유사하게 Delta Rule 기반이지만, 게이트가 head-wise 차원에 적용되는 구조적 차이가 있다. SGLang의 KDAAttnBackend는 Triton과 CuTe DSL 두 가지 커널을 지원하며, GDN 대비 더 간결한 구조를 가진다.

이 글에서는 python/sglang/srt/layers/attention/linear/kda_backend.py를 분석한다.

구조도

┌─────────────────────────────────────────────────┐
│                 KDAAttnBackend                   │
│                                                 │
│  ┌───────────────────────────────────────┐      │
│  │         KDAKernelDispatcher           │      │
│  │  ┌──────────┐    ┌──────────────┐     │      │
│  │  │ decode   │    │   extend     │     │      │
│  │  │ _kernel  │    │   _kernel    │     │      │
│  │  └────┬─────┘    └──────┬───────┘     │      │
│  │       │                 │             │      │
│  │  ┌────▼─────┐     ┌────▼─────┐       │      │
│  │  │ Triton / │     │  Triton  │       │      │
│  │  │ CuteDSL  │     │  (only)  │       │      │
│  │  └──────────┘     └──────────┘       │      │
│  └───────────────────────────────────────┘      │
│                                                 │
│  forward_decode()                               │
│    → causal_conv1d_update (Triton fallback)     │
│    → split q,k,v → unflatten → decode kernel   │
│                                                 │
│  forward_extend()                               │
│    → split conv_weights per q,k,v               │
│    → causal_conv1d_fn × 3                       │
│    → extend kernel                              │
└─────────────────────────────────────────────────┘

핵심 코드 분석

KDAKernelDispatcher: GDN 대비 제한적 커널 조합

class KDAKernelDispatcher:
    def __init__(self, decode_backend, prefill_backend):
        triton_kernel = TritonKDAKernel()

        if decode_backend.is_triton():
            self.decode_kernel = triton_kernel
        elif decode_backend.is_cutedsl():
            if not is_cuda():
                raise ValueError("KDA CuTe DSL backend requires CUDA")
            from ...kernels.kda_cutedsl import CuteDSLKDAKernel
            self.decode_kernel = CuteDSLKDAKernel()
        else:
            raise ValueError(
                f"Unsupported KDA decode backend: {decode_backend}. "
                "KDA currently only supports 'triton'."
            )

        if prefill_backend.is_triton():
            self.extend_kernel = triton_kernel
        else:
            raise ValueError(
                f"Unsupported KDA prefill backend: {prefill_backend}. "
                "KDA currently only supports 'triton'."
            )

GDN과 비교하면 KDA의 커널 지원은 제한적이다. Decode에 Triton과 CuTe DSL, Extend에는 Triton만 사용할 수 있다. FlashInfer 백엔드는 아직 KDA를 지원하지 않는다. 이는 KDA가 GDN 대비 후발 주자이기 때문이다.

KDA vs GDN: Causal Convolution의 차이

GDN은 mixed_qkv를 하나로 묶어서 Convolution하지만, KDA는 Q, K, V를 개별적으로 Convolution한다.

# KDA: Conv 가중치를 Q, K, V별로 분리
q_conv_weight, k_conv_weight, v_conv_weight = layer.conv_weights.split(
    splits, dim=0
)
q_conv_state, k_conv_state, v_conv_state = conv_states.split(splits, dim=-2)

# Q, K, V 각각 별도 causal_conv1d_fn 호출
q = causal_conv1d_fn(
    q, q_conv_weight, q_bias, activation="silu",
    conv_states=q_conv_state,
    has_initial_state=has_initial_state,
    cache_indices=cache_indices,
    query_start_loc=query_start_loc,
    seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
k = causal_conv1d_fn(
    k, k_conv_weight, k_bias, activation="silu",
    conv_states=k_conv_state, ...
).transpose(0, 1)
v = causal_conv1d_fn(
    v, v_conv_weight, v_bias, activation="silu",
    conv_states=v_conv_state, ...
).transpose(0, 1)

GDN은 Q+K+V를 결합하여 한 번의 causal_conv1d_fn으로 처리하지만, KDA는 세 번 호출한다. 이는 KDA 모델의 Conv 가중치가 Q, K, V별로 독립적으로 학습되기 때문이다. 커널 호출 횟수는 3배이지만, 각 호출의 차원이 1/3이므로 총 연산량은 유사하다.

Forward Decode: Conv State Transpose

def forward_decode(self, layer, mixed_qkv, a, b, **kwargs):
    layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer.layer_id)
    conv_states = layer_cache.conv[0]
    ssm_states = layer_cache.temporal

    qkv = causal_conv1d_update(
        mixed_qkv,
        conv_states.transpose(-1, -2),  # KDA 특유의 transpose
        layer.conv_weights, layer.bias,
        activation="silu",
        conv_state_indices=cache_indices,
    )
    q, k, v = qkv.split([layer.q_dim, layer.k_dim, layer.v_dim], dim=-1)
    q = q.unflatten(-1, (-1, layer.head_q_dim)).unsqueeze(0)
    k = k.unflatten(-1, (-1, layer.head_k_dim)).unsqueeze(0)
    v = v.unflatten(-1, (-1, layer.head_v_dim)).unsqueeze(0)

KDA는 conv_statestranspose(-1, -2)하여 사용한다. GDN은 Conv State를 [pool_size, conv_dim, conv_kernel] 형태로 저장하지만, KDA는 [pool_size, conv_kernel, conv_dim] 형태를 기대한다. 이 차이는 KDA 모델의 Causal Conv 구현이 다른 메모리 레이아웃을 사용하기 때문이다.

Forward Extend: Triton Conv만 사용

# KDA always uses the triton causal_conv1d_fn (no CUDA override).
# Only causal_conv1d_update needs platform-specific overrides for decode.

소스 코드 주석에서 명시한 것처럼, KDA의 Extend는 항상 Triton 구현의 causal_conv1d_fn을 사용한다. CUDA 네이티브 구현으로의 오버라이드는 Decode의 causal_conv1d_update에만 적용된다. 이는 KDA의 개별 Q/K/V Convolution이 CUDA 네이티브 커널의 최적화 경로와 맞지 않기 때문이다.

Extend 커널: g와 beta 직접 전달

def forward_extend(self, layer, forward_batch, mixed_qkv, a, b, **kwargs):
    # ... Q, K, V 개별 conv 처리 후 ...

    core_attn_out = self.kernel_dispatcher.extend(
        q=q, k=k, v=v,
        g=a,      # KDA: a를 gate로 직접 사용
        beta=b,   # KDA: b를 beta로 직접 사용
        ssm_states=ssm_states,
        cache_indices=cache_indices,
        query_start_loc=query_start_loc,
    )

GDN은 fused_gdn_gating(A_log, a, b, dt_bias)로 gate와 beta를 계산하지만, KDA는 ab를 직접 gbeta로 전달한다. KDA 모델에서는 a가 이미 gate의 역할을 하도록 학습되었기 때문이다. 이 차이가 KDA의 구조적 단순함을 만든다.

GDN vs KDA 비교

특성 GDN KDA
게이팅 fused_gdn_gating 별도 계산 a, b 직접 사용
Convolution Q+K+V 합쳐서 1회 Q, K, V 분리하여 3회
Conv State 레이아웃 [pool, dim, kernel] [pool, kernel, dim] (transpose)
Decode 커널 Triton / FlashInfer / CuTe DSL Triton / CuTe DSL
Extend 커널 Triton / FlashInfer Triton만
Packed Decode 지원 (FlashInfer) 미지원
Target Verify 지원 (별도 verify 커널) 미구현

설계 근거

KDA가 GDN보다 간결한 이유는 모델 아키텍처 자체의 차이에서 온다. GDN은 A_log, dt_bias를 포함하는 SSM 스타일의 파라미터화를 사용하여 별도 게이팅 연산이 필요하다. KDA는 Convolution과 게이팅을 분리하여, 프로젝션 레이어에서 이미 적절한 a, b 값을 만들어낸다. 서빙 관점에서 KDA는 구현이 단순하지만 커널 최적화 여지는 GDN보다 적다.

관련 포스트

  • GDN (Gated Diagonal Net): 게이트 기반 선형 어텐션
  • FLA (Flashy Linear Attention): 청크 기반 선형 어텐션 연산
  • Lightning Attention: 고속 선형 어텐션 구현

참고

댓글

관련 포스트

SGLang 의 다른글