[SGLang] Lightning Attention: 고속 선형 어텐션 구현
들어가며
Lightning Attention은 MiniMax에서 개발한 IO-aware 선형 어텐션이다. 선형 어텐션의 수학적 이점(O(n) 복잡도)을 유지하면서, 실제 GPU에서의 메모리 접근 패턴을 최적화한다. SGLang의 LightningAttentionBackend는 두 가지 구현 경로를 제공한다: MiniMax의 블록 기반 커널(minimax)과 Ant Group의 Segment Linear Attention(seg_la).
이 글에서는 python/sglang/srt/layers/attention/linear/lightning_backend.py를 분석한다.
구조도
┌──────────────────────────────────────────────────────┐
│ LightningAttentionBackend │
│ │
│ ┌──────────────┐ ┌──────────────────────────┐ │
│ │ tp_slope │ │ linear_backend 설정 │ │
│ │ (ALiBi-like │ │ "minimax" | "seg_la" │ │
│ │ decay) │ └──────────┬───────────────┘ │
│ └──────────────┘ │ │
│ ┌───────▼────────┐ │
│ ┌───────┤ 분기 선택 ├───────┐ │
│ │ └────────────────┘ │ │
│ ▼ ▼ │
│ ┌───────────────────────┐ ┌────────────────────┐ │
│ │ minimax (블록 기반) │ │ seg_la (세그먼트) │ │
│ │ _prefill_and_mix_infer│ │ _linear_attention │ │
│ │ _decode_infer │ │ _entry │ │
│ └───────────────────────┘ └────────────────────┘ │
└──────────────────────────────────────────────────────┘
핵심 코드 분석
Slope 텐서 구축: ALiBi 스타일의 Decay
@staticmethod
def _build_slope_tensor(n_attention_heads, num_hidden_layers, device="cuda"):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
slopes = torch.tensor(get_slopes(n_attention_heads), dtype=torch.float32)
.reshape(n_attention_heads, 1, 1)
# 레이어별로 다른 decay rate 적용
slope_rate_list = [
slopes * (1 - layer_id / (num_hidden_layers - 1) + 1e-5)
for layer_id in range(num_hidden_layers)
]
Lightning Attention은 ALiBi(Attention with Linear Biases)에서 영감받은 위치 decay를 사용한다. 헤드별로 다른 기하급수적 decay rate를 할당하고, 레이어가 깊어질수록 decay를 줄여 상위 레이어가 더 넓은 컨텍스트를 볼 수 있게 한다. TP(Tensor Parallelism) 환경에서는 각 rank가 자기 담당 헤드의 slope만 가져간다.
Prefill: 블록 기반 Intra + Inter 어텐션
def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
forward_batch, layer, metadata):
hidden = []
for _prefill_idx in range(metadata.num_prefills):
_start = forward_batch.extend_start_loc[_prefill_idx]
if _prefill_idx + 1 < forward_batch.extend_start_loc.shape[0]:
_end = forward_batch.extend_start_loc[_prefill_idx + 1]
else:
_end = q.shape[0]
slot_id = state_indices_tensor[_prefill_idx]
qs = q[_start:_end].transpose(0, 1).contiguous()
ks = k[_start:_end].transpose(0, 1).contiguous()
vs = v[_start:_end].transpose(0, 1).contiguous()
slice_layer_cache = kv_cache[slot_id, ...]
out_slice = BailingLinearKernel.jit_linear_forward_prefix(
qs, ks, vs, slice_layer_cache,
self.tp_slope[layer.layer_id], self.BLOCK,
layer_idx=layer.layer_id,
)
hidden.append(out_slice.contiguous())
minimax 경로의 Prefill은 각 요청을 개별적으로 처리한다. BailingLinearKernel.jit_linear_forward_prefix는 시퀀스를 BLOCK 크기(기본 256)의 블록으로 나누어, 블록 내부(intra)는 삼각 어텐션으로, 블록 간(inter)은 KV 상태를 누적하여 계산한다. slice_layer_cache는 각 요청의 SSM 상태를 가리키며, 처리 후 업데이트된 상태가 저장된다.
Decode: Triton 커널로 단일 토큰 처리
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, metadata, layer):
num_prefill_tokens = metadata.num_prefill_tokens
num_prefills = metadata.num_prefills
q = q[num_prefill_tokens:].unsqueeze(2).contiguous()
k = k[num_prefill_tokens:].unsqueeze(2).contiguous()
v = v[num_prefill_tokens:].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[num_prefills:]
hidden = linear_decode_forward_triton(
q, k, v, kv_cache,
self.tp_slope[layer.layer_id], slot_id, 32
)
return hidden
Decode에서는 Prefill 토큰을 건너뛰고 Decode 토큰만 추출한다. linear_decode_forward_triton은 Triton으로 작성된 커널로, 각 토큰에 대해 KV 상태를 읽고 → decay를 적용하고 → 새 k, v로 상태를 업데이트하고 → 출력을 계산한다. 마지막 인자 32는 블록 크기 파라미터이다.
Seg-LA: Segment Linear Attention 경로
def _linear_attention_entry(self, q, k, v, kv_cache, state_indices_tensor,
metadata, layer, mask=None, temp_cache=None,
intermediate_state_indices=None):
seg_meta = SegLaMeta(
batch_size=metadata.batch_size,
q_offsets=metadata.query_start_loc,
s_offsets=state_indices_tensor,
q_lengths=q_offsets.diff(),
s_scales=metadata.has_initial_states,
max_q_length=None,
mask=mask,
)
hidden = seg_la_fwd(
q=q, k=k, v=v, s=kv_cache,
decay_scales=self.tp_slope[layer.layer_id],
meta=seg_meta, caches=temp_cache,
cache_indices=intermediate_state_indices,
decouple=True,
)
return hidden
seg_la 경로는 Ant Group의 Segment Linear Attention을 사용한다. SegLaMeta에 배치 정보를 담고, seg_la_fwd가 단일 Triton 커널로 Prefill과 Decode를 모두 처리한다. s_scales는 초기 상태 유무를 나타내는 플래그로, Prefix Caching 시 이전 상태를 로드할지 결정한다. decouple=True는 key와 value의 decay를 분리 적용한다.
Mixed Batch: Prefill + Decode 통합
def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
if self.linear_backend == "minimax":
o = self._prefill_and_mix_infer(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k, v, ssm_states, cache_indices,
forward_batch, layer, metadata,
)
elif self.linear_backend == "seg_la":
o = self._linear_attention_entry(
q, k, v, ssm_states, cache_indices, metadata, layer,
temp_cache=(mamba_cache_params.intermediate_ssm
if forward_batch.forward_mode.is_target_verify() else None),
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
linear_backend 설정으로 두 경로를 선택한다. minimax는 Prefill과 Decode를 명시적으로 분리하여 각각 최적화된 커널을 호출한다. seg_la는 단일 커널이 두 모드를 모두 처리하여 코드가 간결하다. Target Verify 모드에서는 intermediate_ssm 캐시를 전달하여 중간 상태를 저장한다.
Intra-Block 어텐션: Diagonal 커널
# lightning_attn.py
@triton.jit
def _fwd_diag_kernel(
Q, K, V, Out, S,
b: tl.constexpr, h: tl.constexpr, n,
d: tl.constexpr, e: tl.constexpr,
BLOCK: tl.constexpr, NUM_BLOCK, CBLOCK: tl.constexpr,
):
off = tl.program_id(0)
off_bh = off // NUM_BLOCK
off_block = off % NUM_BLOCK
off_cblock = tl.program_id(1)
블록 내부(intra-block) 어텐션은 _fwd_diag_kernel로 계산한다. 블록 크기 BLOCK 안에서 삼각 마스크를 적용한 일반 어텐션을 수행한다. CBLOCK은 블록을 더 작은 서브블록으로 나누어 shared memory 사용을 최적화한다. IO-aware 설계의 핵심은 이 블록 크기를 GPU의 SRAM(shared memory)에 맞추는 것이다.
minimax vs seg_la 비교
| 특성 | minimax | seg_la |
|---|---|---|
| 개발사 | MiniMax | Ant Group (PIA) |
| Prefill 처리 | 요청별 개별 루프 | 단일 Fused 커널 |
| Decode 처리 | 별도 Triton 커널 | Prefill과 통합 |
| 구현 복잡도 | 높음 (2개 커널) | 낮음 (1개 커널) |
| Target Verify | 미지원 | 지원 (중간 상태 캐싱) |
| 블록 크기 | 256 (설정 가능) | 커널 내부 결정 |
관련 포스트
- GDN (Gated Diagonal Net): 게이트 기반 선형 어텐션
- KDA (Kernel-Driven Attention): 커널 기반 선형 어텐션
- FLA (Flashy Linear Attention): 청크 기반 선형 어텐션 연산
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] KDA (Kernel-Driven Attention): 커널 기반 선형 어텐션
- 현재글 : [SGLang] Lightning Attention: 고속 선형 어텐션 구현
- 다음글 [SGLang] FLA (Flashy Linear Attention): 청크 기반 선형 어텐션 연산
댓글