[sglang] AMD ROCm 환경에서의 성능 최적화: Triton을 활용한 Fused QK GemmaRMSNorm 구현
PR 링크: sgl-project/sglang#23575 상태: Merged | 변경: +None / -None
들어가며
LLM 추론 엔진인 SGLang에서 모델의 성능은 커널 실행 횟수와 밀접한 관련이 있습니다. 특히 AMD ROCm 플랫폼 환경에서 프로파일링을 수행한 결과, apply_qk_norm 함수가 4개의 개별 커널을 호출하고 있음을 확인했습니다. 이는 CUDA 플랫폼에서 2개의 커널이 오버랩되어 실행되는 것과 비교했을 때 비효율적이며, 결과적으로 E2E(End-to-End) 지연 시간을 증가시키는 원인이 됩니다. 본 PR은 이 4개의 커널을 하나의 Triton 커널로 통합(Fusion)하여 오버헤드를 줄이고 성능을 최적화하는 과정을 다룹니다.
코드 분석
1. python/sglang/srt/models/utils.py: Triton 커널 구현
핵심은 _fused_qk_gemma_rmsnorm_kernel을 작성하여 Q와 K의 정규화를 하나의 커널에서 처리하는 것입니다. 기존에는 여러 단계의 연산이 필요했으나, 이제는 triton.jit을 통해 단일 커널에서 메모리 로드와 연산을 수행합니다.
# Before: 개별 연산 수행 (추상화됨)
# After: 단일 커널에서 Q와 K 정규화 통합
@triton.jit
def _fused_qk_gemma_rmsnorm_kernel(
Q_ptr, K_ptr, Q_out_ptr, K_out_ptr, ...
):
pid = tl.program_id(0)
# Q norm (every block)
# ... (Q 정규화 로직)
# K norm (first k_rows blocks only)
if pid < k_rows:
# ... (K 정규화 로직)
특히 q_stride와 k_stride를 커널에 전달함으로써, 메모리 상에서 비연속적인(non-contiguous) 텐서도 추가적인 .contiguous() 복사 없이 직접 처리할 수 있게 설계되었습니다.
2. python/sglang/srt/models/qwen3_5.py: ROCm 환경 적용
_is_hip 플래그를 확인하여 ROCm 환경에서만 이 최적화된 커널을 사용하도록 조건부 로직을 추가했습니다.
+ elif _is_hip:
+ q_by_head, k_by_head = fused_qk_gemma_rmsnorm(
+ q, k, self.q_norm.weight.data, self.k_norm.weight.data, ...
+ )
왜 이게 좋은가
이번 최적화의 핵심은 커널 실행 횟수 감소와 메모리 접근 최적화입니다.
- 커널 오버헤드 감소: 4개의 커널을 1개로 통합함으로써 커널 실행 시 발생하는 고정 비용(Launch Overhead)을 획기적으로 줄였습니다.
- 메모리 효율성:
tl.load시 스트라이드를 직접 활용하여 불필요한 메모리 복사를 제거했습니다.
성능 테스트 결과, 동시성(Concurrency)이 4~16인 환경에서 약 1.9%에서 2.6%의 처리량(Throughput) 향상을 보였습니다. 이는 대규모 모델 추론 시 누적되는 지연 시간을 고려할 때 매우 의미 있는 수치입니다.
교훈: GPU 커널 최적화 시, 단순히 연산량을 줄이는 것보다 커널 간의 의존성을 줄이고 메모리 접근 패턴을 최적화하여 커널 실행 횟수를 최소화하는 것이 E2E 성능 향상에 훨씬 효과적입니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang: ROCm 환경에서 Qwen3-VL 디코딩 성능 극대화를 위한 커널 퓨전 최적화
- [sglang] SGLang Triton 커널 최적화: libdevice.tanh 도입과 2D Strided Tensor 지원
- [sglang] SGLang에서 GLM-5 모델 성능 최적화: Aiter 백엔드 활용 및 텐서 패딩 전략
- [sglang] SGLang의 AMD AITER AllReduce 최적화: 하드코딩된 제약 제거 및 성능 개선
- [sglang] SGLang의 AMD GPU 최적화: RMSNorm과 FP8 Per-token Quantization 커널 융합
PR Analysis 의 다른글
- 이전글 [sglang] SGLang MoE 라우팅 최적화: AMD GPU에서 aiter.biased_grouped_topk 활용
- 현재글 : [sglang] AMD ROCm 환경에서의 성능 최적화: Triton을 활용한 Fused QK GemmaRMSNorm 구현
- 다음글 [sglang] SGLang 성능 최적화: torch.cuda.empty_cache() 호출 제어를 통한 가중치 업데이트 병목 해결
댓글