[sglang] SGLang: DeepSeek-R1 FP8 GEMM 성능 회귀 문제 해결 및 최적화
PR 링크: sgl-project/sglang#28073 상태: Merged | 변경: +106 / -15
들어가며
최근 SGLang에서 MiniMax-M2 모델의 정확도 문제를 해결하기 위해 도입된 변경사항이, 의도치 않게 DeepSeek-R1-0528 모델에서 FP8 GEMM 연산이 Triton 커널로 불필요하게 fallback 되는 성능 회귀(regression)를 유발했습니다. 본 PR은 이 문제를 해결하고, 추론 성능을 이전 수준으로 복구하기 위해 GEMM 백엔드 선택 로직을 개선했습니다.
코드 분석
1. python/sglang/srt/layers/quantization/fp8_utils.py
기존 로직은 weight_scale에 format_ue8m0 속성이 있는지 여부를 기준으로 TRTLLM 백엔드 사용 여부를 결정했습니다. 하지만 DeepSeek-R1 모델의 경우, 이 속성이 없더라도 TRTLLM 커널을 사용하는 것이 성능상 유리합니다. 이를 dtype 기반의 체크로 변경하여 불필요한 fallback을 방지했습니다.
Before:
if backend == "trtllm" and (
input_2d.shape[1] < 256 or not getattr(weight_scale, "format_ue8m0", False)
):
return triton_w8a8_block_fp8_linear(...)
After:
if backend == "trtllm" and (
input_2d.shape[1] < 256 or input_2d.dtype != torch.bfloat16
):
return triton_w8a8_block_fp8_linear(...)
2. python/sglang/srt/model_loader/utils.py
이전 PR에서 리베이스 과정 중 포함된 불필요한 데드 코드(post_load_weights)를 제거하여 코드베이스를 정리했습니다.
왜 이게 좋은가
이번 최적화는 단순히 코드를 정리하는 것을 넘어, 실제 추론 처리량(Throughput)을 크게 개선했습니다. 벤치마크 결과에 따르면, 성능 회귀가 발생했던 버전의 출력 토큰 처리량은 약 510 tok/s였으나, 본 수정 이후 이전의 정상 수준인 650 tok/s 이상으로 복구되었습니다.
핵심 교훈:
- 조건부 로직의 정교화: 특정 모델의 특수 포맷(
format_ue8m0)에 의존하기보다, 연산의 안정성과 성능을 보장하는 데이터 타입(bfloat16)을 기준으로 백엔드를 선택하는 것이 더 범용적입니다. - 회귀 테스트의 중요성: 이번 PR에서 추가된
test_flashinfer_trtllm_fp8_fallback.py와 같은 단위 테스트는 향후 유사한 성능 회귀가 발생하지 않도록 방어하는 중요한 안전장치가 됩니다.
참고 자료
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [flashinfer] FlashInfer, SM120 GPU를 위한 희소 MLA 커널 추가로 LLM 추론 속도 향상
- 현재글 : [sglang] SGLang: DeepSeek-R1 FP8 GEMM 성능 회귀 문제 해결 및 최적화
- 다음글 [cpython] CPython unicodedata.normalize() 최적화: Py_UCS4 버퍼 직접 조작으로 성능 향상
댓글