본문으로 건너뛰기

[sglang] AMD ROCm 환경에서의 성능 최적화: Triton을 활용한 Fused QK GemmaRMSNorm 구현

PR 링크: sgl-project/sglang#23575 상태: Merged | 변경: +None / -None

들어가며

LLM 추론 엔진인 SGLang에서 모델의 성능은 커널 실행 횟수와 밀접한 관련이 있습니다. 특히 AMD ROCm 플랫폼 환경에서 프로파일링을 수행한 결과, apply_qk_norm 함수가 4개의 개별 커널을 호출하고 있음을 확인했습니다. 이는 CUDA 플랫폼에서 2개의 커널이 오버랩되어 실행되는 것과 비교했을 때 비효율적이며, 결과적으로 E2E(End-to-End) 지연 시간을 증가시키는 원인이 됩니다. 본 PR은 이 4개의 커널을 하나의 Triton 커널로 통합(Fusion)하여 오버헤드를 줄이고 성능을 최적화하는 과정을 다룹니다.

코드 분석

1. python/sglang/srt/models/utils.py: Triton 커널 구현

핵심은 _fused_qk_gemma_rmsnorm_kernel을 작성하여 Q와 K의 정규화를 하나의 커널에서 처리하는 것입니다. 기존에는 여러 단계의 연산이 필요했으나, 이제는 triton.jit을 통해 단일 커널에서 메모리 로드와 연산을 수행합니다.

# Before: 개별 연산 수행 (추상화됨)
# After: 단일 커널에서 Q와 K 정규화 통합
@triton.jit
def _fused_qk_gemma_rmsnorm_kernel(
    Q_ptr, K_ptr, Q_out_ptr, K_out_ptr, ...
):
    pid = tl.program_id(0)
    # Q norm (every block)
    # ... (Q 정규화 로직)
    
    # K norm (first k_rows blocks only)
    if pid < k_rows:
        # ... (K 정규화 로직)

특히 q_stridek_stride를 커널에 전달함으로써, 메모리 상에서 비연속적인(non-contiguous) 텐서도 추가적인 .contiguous() 복사 없이 직접 처리할 수 있게 설계되었습니다.

2. python/sglang/srt/models/qwen3_5.py: ROCm 환경 적용

_is_hip 플래그를 확인하여 ROCm 환경에서만 이 최적화된 커널을 사용하도록 조건부 로직을 추가했습니다.

+        elif _is_hip:
+            q_by_head, k_by_head = fused_qk_gemma_rmsnorm(
+                q, k, self.q_norm.weight.data, self.k_norm.weight.data, ...
+            )

왜 이게 좋은가

이번 최적화의 핵심은 커널 실행 횟수 감소메모리 접근 최적화입니다.

  1. 커널 오버헤드 감소: 4개의 커널을 1개로 통합함으로써 커널 실행 시 발생하는 고정 비용(Launch Overhead)을 획기적으로 줄였습니다.
  2. 메모리 효율성: tl.load 시 스트라이드를 직접 활용하여 불필요한 메모리 복사를 제거했습니다.

성능 테스트 결과, 동시성(Concurrency)이 4~16인 환경에서 약 1.9%에서 2.6%의 처리량(Throughput) 향상을 보였습니다. 이는 대규모 모델 추론 시 누적되는 지연 시간을 고려할 때 매우 의미 있는 수치입니다.

교훈: GPU 커널 최적화 시, 단순히 연산량을 줄이는 것보다 커널 간의 의존성을 줄이고 메모리 접근 패턴을 최적화하여 커널 실행 횟수를 최소화하는 것이 E2E 성능 향상에 훨씬 효과적입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글