본문으로 건너뛰기

[sglang] SGLang: DeepSeek-R1 FP8 GEMM 성능 회귀 문제 해결 및 최적화

PR 링크: sgl-project/sglang#28073 상태: Merged | 변경: +106 / -15

들어가며

최근 SGLang에서 MiniMax-M2 모델의 정확도 문제를 해결하기 위해 도입된 변경사항이, 의도치 않게 DeepSeek-R1-0528 모델에서 FP8 GEMM 연산이 Triton 커널로 불필요하게 fallback 되는 성능 회귀(regression)를 유발했습니다. 본 PR은 이 문제를 해결하고, 추론 성능을 이전 수준으로 복구하기 위해 GEMM 백엔드 선택 로직을 개선했습니다.

코드 분석

1. python/sglang/srt/layers/quantization/fp8_utils.py

기존 로직은 weight_scaleformat_ue8m0 속성이 있는지 여부를 기준으로 TRTLLM 백엔드 사용 여부를 결정했습니다. 하지만 DeepSeek-R1 모델의 경우, 이 속성이 없더라도 TRTLLM 커널을 사용하는 것이 성능상 유리합니다. 이를 dtype 기반의 체크로 변경하여 불필요한 fallback을 방지했습니다.

Before:

if backend == "trtllm" and (
    input_2d.shape[1] < 256 or not getattr(weight_scale, "format_ue8m0", False)
):
    return triton_w8a8_block_fp8_linear(...)

After:

if backend == "trtllm" and (
    input_2d.shape[1] < 256 or input_2d.dtype != torch.bfloat16
):
    return triton_w8a8_block_fp8_linear(...)

2. python/sglang/srt/model_loader/utils.py

이전 PR에서 리베이스 과정 중 포함된 불필요한 데드 코드(post_load_weights)를 제거하여 코드베이스를 정리했습니다.

왜 이게 좋은가

이번 최적화는 단순히 코드를 정리하는 것을 넘어, 실제 추론 처리량(Throughput)을 크게 개선했습니다. 벤치마크 결과에 따르면, 성능 회귀가 발생했던 버전의 출력 토큰 처리량은 약 510 tok/s였으나, 본 수정 이후 이전의 정상 수준인 650 tok/s 이상으로 복구되었습니다.

핵심 교훈:

  1. 조건부 로직의 정교화: 특정 모델의 특수 포맷(format_ue8m0)에 의존하기보다, 연산의 안정성과 성능을 보장하는 데이터 타입(bfloat16)을 기준으로 백엔드를 선택하는 것이 더 범용적입니다.
  2. 회귀 테스트의 중요성: 이번 PR에서 추가된 test_flashinfer_trtllm_fp8_fallback.py와 같은 단위 테스트는 향후 유사한 성능 회귀가 발생하지 않도록 방어하는 중요한 안전장치가 됩니다.

참고 자료

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글