본문으로 건너뛰기

[SGLang] Lightning Attention: 고속 선형 어텐션 구현

들어가며

Lightning Attention은 MiniMax에서 개발한 IO-aware 선형 어텐션이다. 선형 어텐션의 수학적 이점(O(n) 복잡도)을 유지하면서, 실제 GPU에서의 메모리 접근 패턴을 최적화한다. SGLang의 LightningAttentionBackend는 두 가지 구현 경로를 제공한다: MiniMax의 블록 기반 커널(minimax)과 Ant Group의 Segment Linear Attention(seg_la).

이 글에서는 python/sglang/srt/layers/attention/linear/lightning_backend.py를 분석한다.

구조도

┌──────────────────────────────────────────────────────┐
│              LightningAttentionBackend                │
│                                                      │
│  ┌──────────────┐     ┌──────────────────────────┐   │
│  │   tp_slope   │     │    linear_backend 설정    │   │
│  │  (ALiBi-like │     │  "minimax" | "seg_la"    │   │
│  │   decay)     │     └──────────┬───────────────┘   │
│  └──────────────┘                │                   │
│                          ┌───────▼────────┐          │
│                  ┌───────┤   분기 선택     ├───────┐  │
│                  │       └────────────────┘       │  │
│                  ▼                                ▼  │
│  ┌───────────────────────┐   ┌────────────────────┐  │
│  │  minimax (블록 기반)   │   │  seg_la (세그먼트)  │  │
│  │  _prefill_and_mix_infer│   │  _linear_attention │  │
│  │  _decode_infer        │   │  _entry             │  │
│  └───────────────────────┘   └────────────────────┘  │
└──────────────────────────────────────────────────────┘

핵심 코드 분석

Slope 텐서 구축: ALiBi 스타일의 Decay

@staticmethod
def _build_slope_tensor(n_attention_heads, num_hidden_layers, device="cuda"):
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    slopes = torch.tensor(get_slopes(n_attention_heads), dtype=torch.float32)
        .reshape(n_attention_heads, 1, 1)

    # 레이어별로 다른 decay rate 적용
    slope_rate_list = [
        slopes * (1 - layer_id / (num_hidden_layers - 1) + 1e-5)
        for layer_id in range(num_hidden_layers)
    ]

Lightning Attention은 ALiBi(Attention with Linear Biases)에서 영감받은 위치 decay를 사용한다. 헤드별로 다른 기하급수적 decay rate를 할당하고, 레이어가 깊어질수록 decay를 줄여 상위 레이어가 더 넓은 컨텍스트를 볼 수 있게 한다. TP(Tensor Parallelism) 환경에서는 각 rank가 자기 담당 헤드의 slope만 가져간다.

Prefill: 블록 기반 Intra + Inter 어텐션

def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                            forward_batch, layer, metadata):
    hidden = []
    for _prefill_idx in range(metadata.num_prefills):
        _start = forward_batch.extend_start_loc[_prefill_idx]
        if _prefill_idx + 1 < forward_batch.extend_start_loc.shape[0]:
            _end = forward_batch.extend_start_loc[_prefill_idx + 1]
        else:
            _end = q.shape[0]

        slot_id = state_indices_tensor[_prefill_idx]
        qs = q[_start:_end].transpose(0, 1).contiguous()
        ks = k[_start:_end].transpose(0, 1).contiguous()
        vs = v[_start:_end].transpose(0, 1).contiguous()
        slice_layer_cache = kv_cache[slot_id, ...]

        out_slice = BailingLinearKernel.jit_linear_forward_prefix(
            qs, ks, vs, slice_layer_cache,
            self.tp_slope[layer.layer_id], self.BLOCK,
            layer_idx=layer.layer_id,
        )
        hidden.append(out_slice.contiguous())

minimax 경로의 Prefill은 각 요청을 개별적으로 처리한다. BailingLinearKernel.jit_linear_forward_prefix는 시퀀스를 BLOCK 크기(기본 256)의 블록으로 나누어, 블록 내부(intra)는 삼각 어텐션으로, 블록 간(inter)은 KV 상태를 누적하여 계산한다. slice_layer_cache는 각 요청의 SSM 상태를 가리키며, 처리 후 업데이트된 상태가 저장된다.

Decode: Triton 커널로 단일 토큰 처리

def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, metadata, layer):
    num_prefill_tokens = metadata.num_prefill_tokens
    num_prefills = metadata.num_prefills
    q = q[num_prefill_tokens:].unsqueeze(2).contiguous()
    k = k[num_prefill_tokens:].unsqueeze(2).contiguous()
    v = v[num_prefill_tokens:].unsqueeze(2).contiguous()
    slot_id = state_indices_tensor[num_prefills:]

    hidden = linear_decode_forward_triton(
        q, k, v, kv_cache,
        self.tp_slope[layer.layer_id], slot_id, 32
    )
    return hidden

Decode에서는 Prefill 토큰을 건너뛰고 Decode 토큰만 추출한다. linear_decode_forward_triton은 Triton으로 작성된 커널로, 각 토큰에 대해 KV 상태를 읽고 → decay를 적용하고 → 새 k, v로 상태를 업데이트하고 → 출력을 계산한다. 마지막 인자 32는 블록 크기 파라미터이다.

Seg-LA: Segment Linear Attention 경로

def _linear_attention_entry(self, q, k, v, kv_cache, state_indices_tensor,
                             metadata, layer, mask=None, temp_cache=None,
                             intermediate_state_indices=None):
    seg_meta = SegLaMeta(
        batch_size=metadata.batch_size,
        q_offsets=metadata.query_start_loc,
        s_offsets=state_indices_tensor,
        q_lengths=q_offsets.diff(),
        s_scales=metadata.has_initial_states,
        max_q_length=None,
        mask=mask,
    )
    hidden = seg_la_fwd(
        q=q, k=k, v=v, s=kv_cache,
        decay_scales=self.tp_slope[layer.layer_id],
        meta=seg_meta, caches=temp_cache,
        cache_indices=intermediate_state_indices,
        decouple=True,
    )
    return hidden

seg_la 경로는 Ant Group의 Segment Linear Attention을 사용한다. SegLaMeta에 배치 정보를 담고, seg_la_fwd가 단일 Triton 커널로 Prefill과 Decode를 모두 처리한다. s_scales는 초기 상태 유무를 나타내는 플래그로, Prefix Caching 시 이전 상태를 로드할지 결정한다. decouple=True는 key와 value의 decay를 분리 적용한다.

Mixed Batch: Prefill + Decode 통합

def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
    if self.linear_backend == "minimax":
        o = self._prefill_and_mix_infer(
            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
            k, v, ssm_states, cache_indices,
            forward_batch, layer, metadata,
        )
    elif self.linear_backend == "seg_la":
        o = self._linear_attention_entry(
            q, k, v, ssm_states, cache_indices, metadata, layer,
            temp_cache=(mamba_cache_params.intermediate_ssm
                        if forward_batch.forward_mode.is_target_verify() else None),
        )
    return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

linear_backend 설정으로 두 경로를 선택한다. minimax는 Prefill과 Decode를 명시적으로 분리하여 각각 최적화된 커널을 호출한다. seg_la는 단일 커널이 두 모드를 모두 처리하여 코드가 간결하다. Target Verify 모드에서는 intermediate_ssm 캐시를 전달하여 중간 상태를 저장한다.

Intra-Block 어텐션: Diagonal 커널

# lightning_attn.py
@triton.jit
def _fwd_diag_kernel(
    Q, K, V, Out, S,
    b: tl.constexpr, h: tl.constexpr, n,
    d: tl.constexpr, e: tl.constexpr,
    BLOCK: tl.constexpr, NUM_BLOCK, CBLOCK: tl.constexpr,
):
    off = tl.program_id(0)
    off_bh = off // NUM_BLOCK
    off_block = off % NUM_BLOCK
    off_cblock = tl.program_id(1)

블록 내부(intra-block) 어텐션은 _fwd_diag_kernel로 계산한다. 블록 크기 BLOCK 안에서 삼각 마스크를 적용한 일반 어텐션을 수행한다. CBLOCK은 블록을 더 작은 서브블록으로 나누어 shared memory 사용을 최적화한다. IO-aware 설계의 핵심은 이 블록 크기를 GPU의 SRAM(shared memory)에 맞추는 것이다.

minimax vs seg_la 비교

특성 minimax seg_la
개발사 MiniMax Ant Group (PIA)
Prefill 처리 요청별 개별 루프 단일 Fused 커널
Decode 처리 별도 Triton 커널 Prefill과 통합
구현 복잡도 높음 (2개 커널) 낮음 (1개 커널)
Target Verify 미지원 지원 (중간 상태 캐싱)
블록 크기 256 (설정 가능) 커널 내부 결정

관련 포스트

  • GDN (Gated Diagonal Net): 게이트 기반 선형 어텐션
  • KDA (Kernel-Driven Attention): 커널 기반 선형 어텐션
  • FLA (Flashy Linear Attention): 청크 기반 선형 어텐션 연산

참고

댓글

관련 포스트

SGLang 의 다른글