[sglang] SGLang의 Qwen3.5 성능 극대화: Fused QK GemmaRMSNorm + RoPE 커널 최적화 분석
PR 링크: sgl-project/sglang#28320 상태: Merged | 변경: +244 / -7
들어가며
최신 대규모 언어 모델(LLM)의 추론 성능을 최적화하는 핵심은 메모리 대역폭 병목을 줄이는 것입니다. 특히 Qwen3.5와 같은 모델은 GemmaRMSNorm, RoPE, 그리고 게이트 디인터리브(deinterleave) 연산이 개별적으로 수행될 때 빈번한 메모리 읽기/쓰기(Global Memory Access)가 발생합니다. 본 PR은 이러한 파편화된 연산들을 하나의 Triton 커널로 통합(Fusion)하여, GPU의 연산 효율을 극대화하고 추론 지연 시간을 줄이는 최적화를 수행했습니다.
코드 분석
1. Triton 커널 통합 (python/sglang/srt/layers/fused_qk_rmsnorm_rope_gate.py)
기존에는 deinterleave -> RMSNorm -> RoPE 순으로 여러 번의 커널 호출이 필요했습니다. 이를 하나의 커널로 통합하여 L1 캐시를 최대한 활용하도록 설계되었습니다.
Before (개념적 구조):
# 별도의 커널 호출들
q, gate = deinterleave(q_gate)
q = gemma_rmsnorm(q)
q = apply_rope(q)
After (Fused Triton Kernel):
@triton.jit
def _fused_qk_rmsnorm_rope_gate_kernel(...):
# ... (RMSNorm 계산)
x_norm = (x * inv_rms * (w + 1.0)).to(out_dtype)
# ... (RoPE 적용)
tl.store(out_base + rot_offs, (xr1 * cos - xr2 * sin), mask=rot_mask)
# ... (Gate 복사)
if HAS_GATE and not is_k:
tl.store(gate_out + head_offs, g, mask=head_mask)
이 방식의 핵심은 in_base에서 데이터를 한 번 로드한 뒤, 레지스터와 L1 캐시 내에서 모든 변환(Norm, RoPE, Gate 분리)을 마친 후 최종 결과만 글로벌 메모리에 쓰는 것입니다.
2. 모델 레이어 적용 (python/sglang/srt/models/qwen3_5.py)
모델의 forward_prepare 단계에서 기존의 파이토치 연산 대신 새로 작성한 커널을 호출하도록 변경되었습니다.
# After: Fused 커널 호출
q_out, k_out, gate_out = fused_qk_gemma_rmsnorm_rope_gate(
q_gate, k, self.q_norm.weight.data, ...
)
왜 이게 좋은가
성능 개선 수치
- B200 TP=8 환경: 4 Concurrency에서 9.4%의 TPS(Tokens Per Second) 향상을 기록했습니다.
- GB300 TP=4 환경: 배치 사이즈 2에서 8.1%, 배치 사이즈 16에서 6.1%의 처리량 향상을 보였습니다.
기술적 교훈
- Memory Bound 최적화: LLM 추론은 연산량보다 메모리 대역폭이 병목인 경우가 많습니다. 여러 연산을 하나로 묶어 중간 결과를 메모리에 쓰지 않는 것만으로도 상당한 성능 이득을 얻을 수 있습니다.
- Hardware-Specific Optimization: 리뷰 과정에서 언급된
tl.extra.cuda.gdc_launch_dependents()를 통해 NVIDIA Hopper 아키텍처 이상에서 지원하는 PDL(Programmatic Dependent Launch)을 활용하여 커널 간 의존성을 최적화했습니다. 이는 하드웨어 기능을 적극 활용하는 좋은 사례입니다. - Triton의 유연성: 복잡한 텐서 레이아웃(interleaved Q+Gate)을 Triton의 2D Grid와
tl.program_id를 활용해 효율적으로 처리할 수 있음을 보여줍니다.
리뷰 피드백 반영
리뷰어 zcnrex는 하드웨어 호환성을 위해 USE_PDL 가드와 tl.extra.cuda.gdc_launch_dependents()의 직접적인 사용을 제안했습니다. 이에 따라 AMD 등 타 하드웨어에서의 런타임 에러를 방지하고 최신 NVIDIA GPU에서의 성능을 보장하는 견고한 코드가 완성되었습니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] Qwen3.5 및 Qwen3_Next 모델의 NPU 성능 향상을 위한 Triton 커널 퓨전 최적화
- [sglang] SGLang의 MLA KV 캐시 쓰기 최적화: TMA Bulk-Store 도입
- [sglang] SGLang Triton 커널 최적화: libdevice.tanh 도입과 2D Strided Tensor 지원
- [sglang] SGLang의 디코드 성능 향상을 위한 Temperature 및 Softmax 커널 융합
- [sglang] [NPU] GLM-4.7-Flash 성능 최적화: Fused Triton 커널로 연산 병목 해결하기
PR Analysis 의 다른글
- 이전글 [loki] Grafana Loki 엔진의 집계 성능 최적화: 메모리 할당 감소와 효율적인 라벨 처리
- 현재글 : [sglang] SGLang의 Qwen3.5 성능 극대화: Fused QK GemmaRMSNorm + RoPE 커널 최적화 분석
- 다음글 [onnxruntime] [ONNX Runtime] SGEMM의 함정에서 벗어나기: GQA 전용 GEMV 커널을 통한 디코딩 최적화
댓글