[sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
PR 링크: sgl-project/sglang#25976 상태: Merged | 변경: +875 / -48
들어가며
대규모 언어 모델(LLM)의 추론 성능은 실시간 서비스 제공에 있어 매우 중요한 요소입니다. 특히, 모델의 각 레이어를 통과하는 과정에서 발생하는 연산들을 얼마나 효율적으로 처리하느냐가 전체적인 Latency와 Throughput에 큰 영향을 미칩니다. sglang 레포지토리의 이번 PR ([DeepSeek-V4] Add mhc_fused_post_pre kernel)은 DeepSeek-V4 모델의 Multi-Head Causal Attention (mHC) 경로를 최적화하여 추론 성능을 향상시키는 것을 목표로 합니다.
기존에는 mHC의 post-step과 pre-step 사이에 별도의 커널 실행 및 스케줄링 경계가 존재했습니다. 이는 특히 Latency에 민감한 디코드(decode) 경로에서 불필요한 오버헤드를 발생시켰습니다. 이 PR은 이 경계를 '융합(fuse)'하여 하나의 커널에서 처리함으로써 이러한 오버헤드를 제거하고 성능을 개선합니다.
이번 글에서는 이 PR의 핵심 변경 사항을 살펴보고, 왜 이러한 최적화가 효과적인지, 그리고 어떤 기술적 교훈을 얻을 수 있는지 분석해 보겠습니다.
코드 변경사항 분석
1. 환경 변수 추가 (python/sglang/srt/environ.py)
새로운 최적화 기능을 활성화하기 위한 환경 변수가 추가되었습니다.
--- a/python/sglang/srt/environ.py
+++ b/python/sglang/srt/environ.py
@@ -651,6 +651,7 @@ class Envs:
SGLANG_OPT_USE_TILELANG_MHC_PRE = EnvBool(True)
SGLANG_OPT_USE_TILELANG_MHC_POST = EnvBool(True)
SGLANG_OPT_USE_TRITON_FUSED_MHC = EnvBool(True)
+ SGLANG_OPT_FUSE_MHC_POST_PRE = EnvBool(False)
SGLANG_OPT_USE_TILELANG_INDEXER = EnvBool(False)
SGLANG_OPT_USE_AITER_INDEXER = EnvBool(False)
SGLANG_OPT_USE_JIT_INDEXER_METADATA = EnvBool(True)
SGLANG_OPT_FUSE_MHC_POST_PRE 환경 변수가 False로 기본 설정되어 있으며, 이를 1로 설정하면 새로운 융합 커널이 활성화됩니다. 이 최적화는 기존의 TileLang 기반 mHC 커널(SGLANG_OPT_USE_TILELANG_MHC_PRE, SGLANG_OPT_USE_TILELANG_MHC_POST)에 의존하므로, 해당 옵션들이 활성화된 상태에서 사용될 때 가장 효과적입니다.
2. Fused mHC Kernel 구현 (python/sglang/srt/layers/mhc.py)
이 PR의 핵심은 mhc_fused_post_pre 함수와 이를 지원하는 mhc_fused_post_pre_fma_tilelang 커널입니다. 이 커널은 기존의 mhc_post와 mhc_pre 단계를 하나의 연산으로 융합합니다.
mhc_fused_post_pre_fma_tilelang:
이 함수는 작은 토큰 배치에 최적화된 TileLang 스칼라-FMA(Fused Multiply-Add) 커널입니다. 이전 mHC post-step의 결과(prev_comb_mix, prev_residual, prev_post_mix)와 현재 입력(hidden_in), 그리고 pre-norm GEMM 연산(pre_fn)을 받아, 하나의 커널 실행으로 다음과 같은 연산을 수행합니다:
- 이전
hc_post결과 계산 - bf16 잔차(residual) 재구성
- pre-norm GEMM의 부분 결과(partials) 계산
- RMS 정규화에 필요한 제곱합(square-sum)의 부분 결과 계산
--- a/python/sglang/srt/layers/mhc.py
+++ b/python/sglang/srt/layers/mhc.py
@@ -896,3 +896,500 @@ def mhc_post(
residual.shape[-1],
)
return out
+
+
+@tilelang.jit(
+ pass_configs={
+ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+ tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
+ },
+) # ... (mhc_fused_post_pre_fma_tilelang 전체 코드)
이 커널은 스레드 블록(CTA) 내에서 토큰, 출력 믹스 컬럼 타일, 히든 차원 분할(split)을 기준으로 작업을 할당합니다. 각 스레드는 특정 히든 위치에 대한 연산을 수행하며, 워프(warp) 내 및 워프 간 통신을 통해 부분 결과들을 취합합니다. 특히, mix_output_tile_idx == 0인 경우에만 잔차와 제곱합을 계산하여 중복을 피합니다.
mhc_fused_post_pre:
이 함수는 실제 디스패치 로직을 담당합니다. 토큰 배치 크기에 따라 두 가지 경로로 나뉩니다:
- 작은 배치:
mhc_fused_post_pre_fma_tilelang커널을 호출합니다. - 큰 배치: 기존의 고성능 TileLang
mhc_post커널과 DeepGEMM 기반의mhc_pre경로를 그대로 사용합니다.
또한, 이 함수는 기존의 mhc_pre_big_fuse(_with_norm) 최종 단계를 유지하며, 0 토큰 DP/EP 랭크를 위한 빈 텐서 처리 등 엣지 케이스를 처리합니다.
3. 모델 통합 (python/sglang/srt/models/deepseek_v4.py 및 deepseek_v4_nextn.py)
DeepSeek-V4 모델 자체에서도 이 융합 기능을 활용하도록 수정되었습니다.
deepseek_v4.py:- 융합된 mHC 활성화 여부를 캐싱합니다.
- 레이어 간(
cross-layer)hc_post와hc_pre를 융합합니다. - 레이어 내(
within-layer)hc_post와hc_pre를 융합합니다. - 최종
hc_post는 모델 테일에서 처리하도록 지연시킵니다. - 융합 경로에서 per-forward 캐스팅을 피하기 위해 bf16 RMSNorm 가중치를 캐싱합니다.
# 예시: deepseek_v4.py 내의 융합 로직 일부 (실제 diff와 다를 수 있음)
# ...
if self.fused_mhc_enabled:
# Fuse cross-layer hc_post + hc_pre
# Fuse within-layer hc_post + hc_pre
# ...
# Defer final FFN hc_post
# ...
else:
# Original path
# ...
# ...
deepseek_v4_nextn.py:- NextN 디코더 사용 시, 지연된 융합 상태를 반환하도록 수정되었습니다.
4. 테스트 코드 보강 (tests/kernels/test_mhc_kernels.py)
새로운 융합 커널의 정확성을 검증하기 위해 테스트 코드가 추가되었습니다. 기존의 비융합(unfused) 경로와 비교하여 정확성을 확인합니다.
리뷰 과정에서 hc_mult 값 조정, 큰 배치 크기 테스트 추가, hidden_size 파라미터화(예: 7168), 융합 커널에 대한 워밍업(warmup) 추가 등 다양한 개선 제안이 있었고, 이에 따라 테스트 커버리지가 강화되었습니다.
--- a/tests/kernels/test_mhc_kernels.py
+++ b/tests/kernels/test_mhc_kernels.py
@@ -12,7 +12,7 @@
def test_mhc_kernels():
# Test mhc_post
- for hc_mult, hc_scale, hc_base, hc_pre_eps, hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat, n_splits, tile_n in [
+ for hc_mult, hc_scale, hc_base, hc_pre_eps, hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat, n_splits, tile_n in [
(1, 1.0, 1.0, 1e-5, 1e-5, 1.0, 1, 1, 1),
(2, 1.0, 1.0, 1e-5, 1e-5, 1.0, 1, 1, 1),
(4, 1.0, 1.0, 1e-5, 1e-5, 1.0, 1, 1, 1),
@@ -20,7 +20,7 @@
(4, 1.0, 1.0, 1e-5, 1e-5, 1.0, 1, 2, 1),
(4, 1.0, 1.0, 1e-5, 1e-5, 1.0, 1, 1, 2),
]:
- for hidden_size in [128, 768]:
+ for hidden_size in [128, 768, 4096]: # Added 4096 for larger tests
for num_tokens in [1, 16, 128]:
for batch_size in [1, 8]:
# ... (rest of the test code)
리뷰어 yhyang201의 제안에 따라 hc_mult=4로 설정하고, hidden_size에 4096을 추가하여 더 큰 모델 구성에 대한 테스트 커버리지를 확보했습니다. 또한, 큰 배치 크기(예: 64)를 테스트하여 원본 비융합 경로의 정확성도 검증하도록 했습니다.
왜 이게 좋은가?
이 PR의 핵심적인 개선은 다음과 같습니다.
- Latency 감소: mHC post-step과 pre-step 사이의 불필요한 커널 실행 및 스케줄링 오버헤드를 제거했습니다. 이는 특히 토큰 생성 과정과 같이 Latency가 중요한 시나리오에서 상당한 성능 향상을 가져올 수 있습니다. PR 설명에 따르면, total_throughput이 3.35% 증가하는 성능 향과가 있었습니다.
- 연산 효율성 증대: 작은 토큰 배치에서는 하나의 FMA 커널에서 여러 연산을 동시에 처리하여 GPU 활용률을 높입니다. 큰 배치에서는 기존의 고성능 커널 조합을 유지하면서도, 융합 가능한 부분을 통합하여 효율성을 개선합니다.
- 메모리 대역폭 절감: 커널 간 데이터 이동을 줄임으로써 메모리 대역폭 사용량을 최적화합니다. 이는 특히 메모리 대역폭이 병목 현상을 일으키는 LLM 추론에서 중요한 이점입니다.
- 정확성 유지: 성능 향상과 더불어 GSM8K 데이터셋에서 0.975의 높은 정확도를 유지하며, 연산 수준에서의 정밀도 검증도 통과했습니다. 이는 최적화가 모델의 예측 결과에 부정적인 영향을 미치지 않음을 의미합니다.
일반적 교훈
- 커널 융합(Kernel Fusion)의 힘: Latency에 민감한 경로에서 여러 작은 커널을 하나의 큰 커널로 융합하는 것은 GPU 연산에서 매우 효과적인 최적화 기법입니다. 커널 실행 오버헤드와 메모리 접근 패턴을 개선하여 성능을 크게 향상시킬 수 있습니다.
- 조건부 최적화: 모든 경우에 단일 최적화 기법이 최선은 아닙니다. 이 PR처럼 입력 크기(토큰 배치)에 따라 다른 최적화 경로(작은 배치는 FMA 융합, 큰 배치는 기존 고성능 경로)를 선택하는 것은 일반적인 패턴입니다.
- 테스트의 중요성: 새로운 최적화 기능은 기존 기능의 동작을 변경할 수 있으므로, 다양한 시나리오(작은/큰 배치, 다양한 모델 크기)에 대한 철저한 테스트와 정확성 검증이 필수적입니다. 리뷰 과정에서의 피드백을 통해 테스트 커버리지를 넓히는 것이 중요합니다.
결론
이번 PR은 DeepSeek-V4 모델의 mHC 경로를 최적화하여 추론 성능을 3.35% 향상시키는 중요한 개선을 이루었습니다. 커널 융합이라는 강력한 기법을 적용하여 Latency를 줄이고 연산 효율성을 높였으며, 동시에 모델의 정확성을 유지했습니다. 이는 LLM 추론 최적화 분야에서 커널 융합의 중요성과 효과를 다시 한번 보여주는 좋은 사례입니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/mhc.py#L896
- https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/deepseek_v4.py
- https://github.com/sgl-project/sglang/blob/main/tests/kernels/test_mhc_kernels.py
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [vllm] vLLM, GDN Prefill 커널을 CuteDSL로 최적화하여 성능 향상
- [flashinfer] FlashInfer, 동적 토큰 페이지 커널 도입으로 TRTLLM-GEN GQA 성능 최적화
- [flashinfer] FlashInfer, CUDA 그래프 호환성을 높이고 성능을 최적화하다: TRT-LLM FMHA v2 통합 및 불필요한 H2D 제거
- [vllm] vLLM, DCP A2A 어텐션 백엔드 최적화: 단일 All-to-All 콜렉티브로 성능 향상
- [sglang] sglang, GLM-5.1-FP8 모델 성능 및 정확도 벤치마크 추가: AMD GPU 환경에서의 최적화 분석
PR Analysis 의 다른글
- 이전글 [vllm] vLLM, DeepSeek-V3.2 모델의 ROCm 성능 최적화: CPU 측 마이크로 최적화 3가지 분석
- 현재글 : [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- 다음글 [sglang] SGLang의 add_constant 커널 최적화: 아키텍처 인지 벡터화(Vectorization) 도입
댓글