본문으로 건너뛰기

[SGLang] FLA (Flashy Linear Attention): 청크 기반 선형 어텐션 연산

들어가며

FLA(Flash Linear Attention)는 선형 어텐션의 청크 기반 연산을 Triton 커널로 구현한 라이브러리이다. 원래 fla-org/flash-linear-attention에서 개발되었으며, SGLang이 서빙에 필요한 핵심 연산을 가져와 사용한다. GDN, KDA 등 선형 어텐션 백엔드의 실제 계산을 수행하는 저수준 커널 모음이다.

이 글에서는 python/sglang/srt/layers/attention/fla/ 디렉토리의 핵심 파일들을 분석한다.

구조도

fla/
├── chunk.py               ← Gated Delta Rule 진입점
├── chunk_delta_h.py        ← 청크 간 상태 전파 (h 업데이트)
├── chunk_fwd.py            ← 청크 내 Forward (kkt + solve_tril)
├── chunk_intra.py          ← Intra-chunk 어텐션
├── chunk_o.py              ← 출력 계산
├── cumsum.py               ← 청크별 누적합
├── fused_recurrent.py      ← Fused Recurrent 커널 (Decode용)
├── fused_gdn_gating.py     ← GDN 게이팅 Fused 커널
├── fused_norm_gate.py      ← LayerNorm + Gate Fused 커널
├── op.py                   ← Triton 기본 연산 (exp, log, gather)
├── index.py                ← 청크 인덱스 유틸리티
├── l2norm.py               ← L2 Normalization
├── utils.py                ← 공통 유틸리티
└── wy_fast.py              ← WY 분해 기반 w, u 재계산

           Chunk Pipeline 흐름
┌──────────────────────────────────────────┐
│  입력: q, k, v, g (gate), beta          │
│                                          │
│  1. chunk_local_cumsum(g)                │
│     └─ gate의 청크별 누적합 계산          │
│                                          │
│  2. chunk_gated_delta_rule_fwd_intra     │
│     └─ beta*K@K^T → solve_tril → w, u, A│
│     └─ 청크 내 삼각 시스템 풀기           │
│                                          │
│  3. chunk_gated_delta_rule_fwd_h         │
│     └─ 청크 간 상태 전파 (h 업데이트)     │
│     └─ initial_state 활용                │
│                                          │
│  4. chunk_fwd_o                          │
│     └─ q @ (h + intra_attn) → output    │
│                                          │
│  출력: o (attention output), h (states)  │
└──────────────────────────────────────────┘

핵심 코드 분석

chunk.py: Gated Delta Rule의 진입점

def chunk_gated_delta_rule_fwd(
    q, k, v, g, beta, scale,
    initial_state, initial_state_indices,
    cu_seqlens=None, chunk_indices=None,
):
    g = chunk_local_cumsum(g, chunk_size=CHUNK_SIZE, cu_seqlens=cu_seqlens)

    # 청크 내 연산: beta * K @ K^T + 삼각 시스템 풀기
    w, u, A = chunk_gated_delta_rule_fwd_intra(
        k=k, v=v, g=g, beta=beta,
        cu_seqlens=cu_seqlens, chunk_indices=chunk_indices,
    )

    # 청크 간 상태 전파
    h, v_new = chunk_gated_delta_rule_fwd_h(
        k=k, w=w, u=u, g=g,
        initial_state=initial_state,
        initial_state_indices=initial_state_indices,
        cu_seqlens=cu_seqlens,
    )

    # 최종 출력 계산
    o = chunk_fwd_o(
        q=q, k=k, v=v_new, h=h, g=g, scale=scale,
        cu_seqlens=cu_seqlens,
    )
    return g, o, A, w, h, v_new

Gated Delta Rule의 Forward는 4단계 파이프라인이다. CHUNK_SIZE는 64로 고정되어 있다. chunk_local_cumsum은 gate g의 청크별 누적합을 계산하여 decay를 적용한다. 이후 청크 내 연산(intra), 청크 간 상태 전파(inter), 출력 계산을 순서대로 수행한다.

ChunkGatedDeltaRuleFunction: L2 Norm 지원

class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, g, beta, scale,
                initial_state, initial_state_indices,
                cu_seqlens=None, use_qk_l2norm_in_kernel=False):
        q_orig, k_orig = q, k
        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)

        chunk_indices = (
            prepare_chunk_indices(cu_seqlens, CHUNK_SIZE)
            if cu_seqlens is not None else None
        )
        g, o, A, w, h, v_new = chunk_gated_delta_rule_fwd(
            q=q, k=k, v=v, g=g, beta=beta, scale=scale,
            initial_state=initial_state,
            initial_state_indices=initial_state_indices,
            cu_seqlens=cu_seqlens, chunk_indices=chunk_indices,
        )
        return o.to(q.dtype), h

use_qk_l2norm_in_kernel이 True이면 Q, K에 L2 정규화를 적용한 후 Gated Delta Rule을 수행한다. L2 Norm은 어텐션 스코어의 수치 안정성을 높인다. torch.autograd.Function으로 감싸 학습 시 Backward도 지원하지만, SGLang 서빙에서는 Forward만 사용한다.

chunk_delta_h: 청크 간 상태 전파

@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    k, v, w, v_new, g, gk, h,
    initial_state, initial_state_indices, cu_seqlens, chunk_offsets,
    T, H: tl.constexpr, Hg: tl.constexpr,
    K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, BV: tl.constexpr,
    ...
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H

    # [BV, BK] 상태 행렬 초기화
    b_h1 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([BV, 64], dtype=tl.float32)

이 커널은 청크 간 상태(h)를 전파하는 핵심 연산이다. 상태 행렬의 크기가 K(head dimension)에 따라 달라지므로, K > 64, K > 128, K > 192 케이스를 수동으로 분기한다. tl.constexpr로 선언된 K가 컴파일 타임에 결정되므로, 불필요한 분기는 제거된다. Variable Length(IS_VARLEN) 모드에서는 cu_seqlenschunk_offsets로 시퀀스 경계를 추적한다.

fused_recurrent: Decode용 Fused Recurrent 커널

@triton.jit(do_not_specialize=["T"])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q, k, v, g, beta, o, h0, ht, cu_seqlens, scale,
    T, B: tl.constexpr, H: tl.constexpr, HV: tl.constexpr,
    K: tl.constexpr, V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
    ...
    IS_KDA: tl.constexpr,  # GDN vs KDA 구분
):
    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * V * K + o_v[:, None] * K + o_k[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_q = b_q * scale

        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)          # GDN: 스칼라 gate
        else:
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            b_h *= exp(b_gk[None, :])  # KDA: head-wise gate

        # Delta Rule 업데이트
        b_v -= tl.sum(b_h * b_k[None, :], 1)

이 커널은 GDN과 KDA 모두에서 사용되는 Fused Recurrent 연산이다. IS_KDA 플래그로 게이트 적용 방식을 구분한다. GDN은 스칼라 gate b_g를 상태 전체에 동일하게 적용하지만, KDA는 head-wise gate b_gk를 K 차원별로 다르게 적용한다. Delta Rule의 핵심인 b_v -= tl.sum(b_h * b_k[None, :], 1)는 현재 상태에서 key에 해당하는 정보를 제거한 후 새로운 value를 추가하는 연산이다.

chunk_fwd: Fused KKT + Solve Tril

@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kkt_solve_kernel(
    k, g, beta, A, cu_seqlens, chunk_indices,
    T, H: tl.constexpr, Hg: tl.constexpr,
    K: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, BK: tl.constexpr,
    ...
):
    """
    Fused kernel: compute beta * K @ K^T (lower triangular)
    + solve_tril (I+A)^{-1} in one pass.
    """

이 커널은 두 연산을 Fuse한다. 먼저 beta * K @ K^T의 하삼각 부분을 계산하고, 바로 이어서 (I + A)^{-1}을 삼각 시스템으로 풀어낸다. 원래는 두 개의 별도 커널이었지만, GPU 메모리 왕복을 줄이기 위해 하나로 합쳤다. TF32 정밀도(_MERGE_DOT_PRECISION)를 블록 병합 dot product에 사용하여 SM90에서 2배 빠르게 처리한다.

op.py: Fast Math 연산

if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
    exp = tldevice.fast_expf
    exp2 = tldevice.exp2
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:
    exp = tl.exp
    exp2 = tl.math.exp2
    log = tl.log
    log2 = tl.log2

@triton.jit
def safe_exp(x):
    return exp(tl.where(x <= 0, x, float("-inf")))

FLA_USE_FAST_OPS=1로 설정하면 CUDA의 __expf 같은 빠른 근사 함수를 사용한다. 정밀도가 약간 떨어지지만 throughput이 높아진다. safe_exp는 양수 입력을 -inf로 클램프하여 오버플로를 방지한다. Gate 값이 log space에 있으므로 항상 음수여야 하기 때문이다.

fused_norm_gate: LayerNorm + Gate Fused 커널

@triton.jit
def layer_norm_gated_fwd_kernel(
    x, g, y, w, b, residual, residual_out,
    mean, rstd, eps, T,
    D: tl.constexpr, BT: tl.constexpr, BD: tl.constexpr,
    ACTIVATION: tl.constexpr, IS_RMS_NORM: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr, HAS_RESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr, HAS_BIAS: tl.constexpr,
):

Mamba와 GDN 모델에서 자주 사용되는 패턴은 LayerNorm(x) * activation(gate)이다. 이 커널은 LayerNorm(또는 RMSNorm), 활성화 함수, 게이트 곱셈을 하나의 Triton 커널로 Fuse한다. Residual 연산도 선택적으로 포함할 수 있어, 메모리 왕복을 최소화한다.

설계 근거: 왜 Chunk 기반인가

선형 어텐션의 재귀적 상태 업데이트 h_t = decay * h_{t-1} + k_t * v_t^T는 본질적으로 순차적이다. 청크 기반 접근법은 이를 두 수준으로 나눈다:

  1. Intra-chunk: 청크 내부(64 토큰)는 행렬 곱으로 병렬화
  2. Inter-chunk: 청크 간 상태는 순차적으로 전파

64 토큰 청크는 GPU의 warp 크기(32)의 2배이고, shared memory에 상태 행렬을 올릴 수 있는 적절한 크기이다. 이 하이브리드 전략으로 완전 순차 대비 수십 배의 병렬성을 확보하면서도, 완전 병렬(Transformer 스타일) 대비 O(n) 메모리를 유지한다.

관련 포스트

  • GDN (Gated Diagonal Net): 게이트 기반 선형 어텐션
  • KDA (Kernel-Driven Attention): 커널 기반 선형 어텐션
  • Mamba (SSM): 선형 시간 복잡도 시퀀스 모델링

참고

댓글

관련 포스트

SGLang 의 다른글