본문으로 건너뛰기

[SGLang] GDN의 kkt + solve_tril을 하나의 Triton 커널로 퓨전

PR 링크: sgl-project/sglang#21411 상태: Merged | 변경: +456 / -17

들어가며

Gated Delta Rule (GDN)은 linear attention의 변형으로, 각 청크 내에서 beta * K @ K^T의 하삼각 행렬을 계산(kkt)한 다음 (I+A)^{-1}을 forward substitution으로 풀어(solve_tril) intra-chunk attention을 구한다. 기존에는 이 두 단계가 별도 커널이어서 중간 행렬 A가 HBM에 쓰였다가 다시 읽혔다.

이 PR은 두 연산을 하나의 Triton 커널 chunk_gated_delta_rule_fwd_kkt_solve_kernel로 퓨전하여, 중간 결과를 레지스터에 유지한 채 바로 forward substitution까지 완료한다.

핵심 코드 분석

기존 3단계 파이프라인

Before:

# 단계 1: K@K^T 계산 (별도 커널)
A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, ...)

# 단계 2: 삼각 행렬 풀이 (별도 커널)
A = solve_tril(A=A, ...)

# 단계 3: w, u 재계산
w, u = recompute_w_u_fwd(k=k, v=v, beta=beta, A=A, g_cumsum=g, ...)

After:

# 단일 퓨전 커널이 kkt + solve_tril + recompute_w_u를 한 번에 처리
w, u, A = chunk_gated_delta_rule_fwd_intra(
    k=k, v=v, g=g, beta=beta,
    cu_seqlens=cu_seqlens, chunk_indices=chunk_indices,
)

퓨전 커널의 핵심 구조

BT=64 크기의 청크를 BC=16 크기의 4개 서브 청크로 나누어 10개의 하삼각 블록을 레지스터에서 계산한다.

# Step 1: 10개 [BC,BC] 블록의 K@K^T를 레지스터에서 계산
b_A00 = tl.zeros([BC, BC], dtype=tl.float32)  # 대각 블록 4개
b_A10 = tl.zeros([BC, BC], dtype=tl.float32)  # 비대각 블록 6개
...
for i_k in range(tl.cdiv(K, BK)):
    b_A00 += tl.dot(b_k0, tl.trans(b_k0))
    b_A10 += tl.dot(b_k1, tl.trans(b_k0))
    ...

# Step 2: gate와 beta 스케일링 (레지스터에서)
b_A00 *= safe_exp(b_g0[:, None] - b_g0[None, :])

# Step 3: 대각 블록 forward substitution (레지스터에서)
# Step 4: 블록 간 병합으로 전체 (I+A)^{-1} 완성

왜 이게 좋은가

  1. HBM 왕복 제거: 중간 행렬 A의 크기는 B*H*T*BT float이다. 이를 쓰고 읽는 대역폭이 완전히 절약된다.
  2. 레지스터 활용 극대화: 10개의 [16,16] 블록(각 256 float = 1KB)을 레지스터에 유지하여 L2 캐시 압력도 줄인다.
  3. Varlen 지원: chunk_indices를 통해 가변 길이 시퀀스 패킹에서도 올바른 청크 경계를 처리한다.

정리

커널 퓨전의 교과서적인 사례다. 두 개의 memory-bound 커널을 하나로 합쳐 중간 텐서의 HBM 왕복을 제거하고, 블록 단위 레지스터 타일링으로 compute efficiency를 높였다. GDN 기반 모델의 prefill 성능이 직접적으로 향상될 것이다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석과 해석에서 오류가 있을 수 있으니, 정확한 내용은 원본 PR을 참고해주세요.

댓글

관련 포스트

PR Analysis 의 다른글