본문으로 건너뛰기

[sglang] SGLang의 MHC 파이프라인 최적화: 커널 퓨전과 DeepGemm 도입

PR 링크: sgl-project/sglang#24775 상태: Merged | 변경: +0 / -0

들어가며

LLM 서빙 엔진인 SGLang에서 Multi-Head Latent Attention(MHC) 파이프라인은 연산 집약적인 작업 중 하나입니다. 기존 구현은 여러 개의 독립적인 커널 호출과 중간 결과물에 대한 HBM(High Bandwidth Memory) 접근이 잦아 병목 현상이 발생했습니다. 이번 PR은 커널 퓨전(Kernel Fusion)과 DeepGemm을 도입하여 이러한 오버헤드를 획기적으로 줄였습니다.

코드 분석

1. MHC Pre-GEMM 및 RMSNorm 퓨전

가장 핵심적인 변경 사항은 mhc_pre_big_fuse_with_norm_tilelang 커널의 도입입니다. 기존에는 GEMM 연산 후 별도의 RMSNorm 커널을 호출했으나, 이를 하나로 통합했습니다.

Before (개념적):

# 별도 커널 호출
gemm_out = run_gemm(residual)
norm_out = run_rms_norm(gemm_out)

After:

# TileLang을 이용한 퓨전 커널
@tilelang.jit
def mhc_pre_big_fuse_with_norm_tilelang(...):
    # ... (GEMM 연산)
    # 중간 결과를 Shared Memory에 유지하며 RMSNorm 적용
    for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2):
        # ... (Norm 연산 후 HBM에 한 번만 기록)

이 방식을 통해 HBM으로의 불필요한 쓰기/읽기 왕복을 제거했습니다.

2. DeepGemm 도입

SGLANG_OPT_DEEPGEMM_HC_PRENORM 환경 변수가 활성화된 경우, DeepGemm의 tf32_hc_prenorm_gemm을 사용하여 GEMM 연산 효율을 높였습니다.

if envs.SGLANG_OPT_DEEPGEMM_HC_PRENORM.get():
    import deep_gemm
    deep_gemm.tf32_hc_prenorm_gemm(residual_flat, fn_flat, gemm_out_mul, gemm_out_sqrsum, num_splits=n_splits)

왜 이게 좋은가

이번 최적화는 단순히 연산 속도뿐만 아니라 메모리 대역폭 효율성을 극대화했습니다.

  1. 커널 퓨전: RMSNormmhc_pre 커널 내부로 통합함으로써, 중간 데이터가 HBM으로 나갔다가 다시 들어오는 비용을 제거했습니다. 이는 특히 토큰 수가 많을 때(2048 이상) 성능 향상이 두드러집니다.
  2. DeepGemm 활용: TF32 연산을 최적화된 라이브러리로 처리하여 하드웨어 가속을 극대화했습니다.
  3. 성능 수치: 마이크로 벤치마크 결과, hc_head 연산에서 토큰 수 2048 기준 기존 대비 약 2.5배2.7배의 속도 향상을 보였습니다. norm + mhc_pre 파이프라인 역시 전반적으로 1.3배1.9배의 속도 개선을 달성했습니다.

교훈: GPU 연산 최적화의 핵심은 '연산량 감소'보다 '메모리 접근 횟수 최소화'에 있습니다. 커널을 퓨전하여 Shared Memory 내에서 데이터를 처리하는 것이 현대적인 LLM 가속의 정석임을 다시 한번 확인시켜 줍니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글