[sglang] SGLang의 MHC 파이프라인 최적화: 커널 퓨전과 DeepGemm 도입
PR 링크: sgl-project/sglang#24775 상태: Merged | 변경: +0 / -0
들어가며
LLM 서빙 엔진인 SGLang에서 Multi-Head Latent Attention(MHC) 파이프라인은 연산 집약적인 작업 중 하나입니다. 기존 구현은 여러 개의 독립적인 커널 호출과 중간 결과물에 대한 HBM(High Bandwidth Memory) 접근이 잦아 병목 현상이 발생했습니다. 이번 PR은 커널 퓨전(Kernel Fusion)과 DeepGemm을 도입하여 이러한 오버헤드를 획기적으로 줄였습니다.
코드 분석
1. MHC Pre-GEMM 및 RMSNorm 퓨전
가장 핵심적인 변경 사항은 mhc_pre_big_fuse_with_norm_tilelang 커널의 도입입니다. 기존에는 GEMM 연산 후 별도의 RMSNorm 커널을 호출했으나, 이를 하나로 통합했습니다.
Before (개념적):
# 별도 커널 호출
gemm_out = run_gemm(residual)
norm_out = run_rms_norm(gemm_out)
After:
# TileLang을 이용한 퓨전 커널
@tilelang.jit
def mhc_pre_big_fuse_with_norm_tilelang(...):
# ... (GEMM 연산)
# 중간 결과를 Shared Memory에 유지하며 RMSNorm 적용
for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2):
# ... (Norm 연산 후 HBM에 한 번만 기록)
이 방식을 통해 HBM으로의 불필요한 쓰기/읽기 왕복을 제거했습니다.
2. DeepGemm 도입
SGLANG_OPT_DEEPGEMM_HC_PRENORM 환경 변수가 활성화된 경우, DeepGemm의 tf32_hc_prenorm_gemm을 사용하여 GEMM 연산 효율을 높였습니다.
if envs.SGLANG_OPT_DEEPGEMM_HC_PRENORM.get():
import deep_gemm
deep_gemm.tf32_hc_prenorm_gemm(residual_flat, fn_flat, gemm_out_mul, gemm_out_sqrsum, num_splits=n_splits)
왜 이게 좋은가
이번 최적화는 단순히 연산 속도뿐만 아니라 메모리 대역폭 효율성을 극대화했습니다.
- 커널 퓨전:
RMSNorm을mhc_pre커널 내부로 통합함으로써, 중간 데이터가 HBM으로 나갔다가 다시 들어오는 비용을 제거했습니다. 이는 특히 토큰 수가 많을 때(2048 이상) 성능 향상이 두드러집니다. - DeepGemm 활용: TF32 연산을 최적화된 라이브러리로 처리하여 하드웨어 가속을 극대화했습니다.
- 성능 수치: 마이크로 벤치마크 결과,
hc_head연산에서 토큰 수 2048 기준 기존 대비 약 2.5배2.7배의 속도 향상을 보였습니다.1.9배의 속도 개선을 달성했습니다.norm + mhc_pre파이프라인 역시 전반적으로 1.3배
교훈: GPU 연산 최적화의 핵심은 '연산량 감소'보다 '메모리 접근 횟수 최소화'에 있습니다. 커널을 퓨전하여 Shared Memory 내에서 데이터를 처리하는 것이 현대적인 LLM 가속의 정석임을 다시 한번 확인시켜 줍니다.
참고 자료
- https://github.com/sgl-project/sglang
- https://docs.nvidia.com/cuda/turing-tuning-guide/index.html#shared-memory-optimizations
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang Triton 커널 최적화: libdevice.tanh 도입과 2D Strided Tensor 지원
- [sglang] SGLang의 디코드 성능 향상을 위한 Temperature 및 Softmax 커널 융합
- [sglang] SGLang 성능 최적화: PDL 도입과 안전한 CUDA 동기화로 DSV3.2/GLM-5 가속하기
- [vllm] vLLM chunk_kda 커널의 숨겨진 상태(h) 레이아웃 불일치 버그 수정 및 정확도 개선
- [sglang] SGLang 성능 최적화: torch.cuda.empty_cache() 호출 제어를 통한 가중치 업데이트 병목 해결
PR Analysis 의 다른글
- 이전글 [openclaw] Telegram 메시지 캐시 최적화: 전체 파일 재작성 대신 변경분만 기록하기
- 현재글 : [sglang] SGLang의 MHC 파이프라인 최적화: 커널 퓨전과 DeepGemm 도입
- 다음글 [cpython] CPython inspect.getattr_static 성능 개선: 일반적인 메타클래스 사례 최적화
댓글