[SGLang] GDN의 kkt + solve_tril을 하나의 Triton 커널로 퓨전
PR 링크: sgl-project/sglang#21411 상태: Merged | 변경: +456 / -17
들어가며
Gated Delta Rule (GDN)은 linear attention의 변형으로, 각 청크 내에서 beta * K @ K^T의 하삼각 행렬을 계산(kkt)한 다음 (I+A)^{-1}을 forward substitution으로 풀어(solve_tril) intra-chunk attention을 구한다. 기존에는 이 두 단계가 별도 커널이어서 중간 행렬 A가 HBM에 쓰였다가 다시 읽혔다.
이 PR은 두 연산을 하나의 Triton 커널 chunk_gated_delta_rule_fwd_kkt_solve_kernel로 퓨전하여, 중간 결과를 레지스터에 유지한 채 바로 forward substitution까지 완료한다.
핵심 코드 분석
기존 3단계 파이프라인
Before:
# 단계 1: K@K^T 계산 (별도 커널)
A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, ...)
# 단계 2: 삼각 행렬 풀이 (별도 커널)
A = solve_tril(A=A, ...)
# 단계 3: w, u 재계산
w, u = recompute_w_u_fwd(k=k, v=v, beta=beta, A=A, g_cumsum=g, ...)
After:
# 단일 퓨전 커널이 kkt + solve_tril + recompute_w_u를 한 번에 처리
w, u, A = chunk_gated_delta_rule_fwd_intra(
k=k, v=v, g=g, beta=beta,
cu_seqlens=cu_seqlens, chunk_indices=chunk_indices,
)
퓨전 커널의 핵심 구조
BT=64 크기의 청크를 BC=16 크기의 4개 서브 청크로 나누어 10개의 하삼각 블록을 레지스터에서 계산한다.
# Step 1: 10개 [BC,BC] 블록의 K@K^T를 레지스터에서 계산
b_A00 = tl.zeros([BC, BC], dtype=tl.float32) # 대각 블록 4개
b_A10 = tl.zeros([BC, BC], dtype=tl.float32) # 비대각 블록 6개
...
for i_k in range(tl.cdiv(K, BK)):
b_A00 += tl.dot(b_k0, tl.trans(b_k0))
b_A10 += tl.dot(b_k1, tl.trans(b_k0))
...
# Step 2: gate와 beta 스케일링 (레지스터에서)
b_A00 *= safe_exp(b_g0[:, None] - b_g0[None, :])
# Step 3: 대각 블록 forward substitution (레지스터에서)
# Step 4: 블록 간 병합으로 전체 (I+A)^{-1} 완성
왜 이게 좋은가
- HBM 왕복 제거: 중간 행렬 A의 크기는
B*H*T*BTfloat이다. 이를 쓰고 읽는 대역폭이 완전히 절약된다. - 레지스터 활용 극대화: 10개의 [16,16] 블록(각 256 float = 1KB)을 레지스터에 유지하여 L2 캐시 압력도 줄인다.
- Varlen 지원:
chunk_indices를 통해 가변 길이 시퀀스 패킹에서도 올바른 청크 경계를 처리한다.
정리
커널 퓨전의 교과서적인 사례다. 두 개의 memory-bound 커널을 하나로 합쳐 중간 텐서의 HBM 왕복을 제거하고, 블록 단위 레지스터 타일링으로 compute efficiency를 높였다. GDN 기반 모델의 prefill 성능이 직접적으로 향상될 것이다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석과 해석에서 오류가 있을 수 있으니, 정확한 내용은 원본 PR을 참고해주세요.
관련 포스트
PR Analysis 의 다른글
- 이전글 [sglang] SGLang Whisper 모델의 CUDA Graph 도입 및 성능 최적화 분석
- 현재글 : [SGLang] GDN의 kkt + solve_tril을 하나의 Triton 커널로 퓨전
- 다음글 [CPython] sqlite3 콜백 컨텍스트의 메모리 관리 버그 수정
댓글