[sglang] SGLang Diffusion 모델의 FP8 GEMM 최적화: 41.5% 성능 향상 달성
PR 링크: sgl-project/sglang#27590 상태: Merged | 변경: +153 / -15
들어가며
최근 대규모 생성형 모델, 특히 Diffusion 모델을 서빙할 때 연산 효율성은 매우 중요한 요소입니다. 기존 SGLang의 WeightOnlyFP8Linear 구현은 매 forward pass마다 FP8 가중치를 compute dtype(FP16/BF16)으로 dequantize한 뒤 F.linear를 호출하는 방식을 사용했습니다. 이 과정은 메모리 대역폭과 연산 오버헤드를 유발하여 전체 추론 속도를 저하시키는 병목이 되었습니다. 본 PR은 이 과정을 최적화하기 위해 기존의 Fused FP8 GEMM 커널을 활용하도록 개선했습니다.
코드 분석
1. weight_only_fp8.py의 핵심 로직 변경
가장 중요한 변경은 _apply_weight_only_fp8_linear 함수입니다. 기존에는 무조건 dequantize를 수행했으나, 이제는 환경 변수를 통해 Fused GEMM 사용 여부를 결정하고, 가능한 경우 커널을 직접 호출합니다.
# Before
def forward(self, x: torch.Tensor) -> torch.Tensor:
compute_dtype = self.compute_dtype or x.dtype
weight = dequantize_rowwise_fp8_weight(self.weight, self.weight_scale, compute_dtype)
return F.linear(x.to(compute_dtype), weight, bias)
# After
def _apply_weight_only_fp8_linear(...):
if enable_fused_w8a8 and _can_apply_fused_w8a8_fp8_linear(...):
return _apply_srt_w8a8_fp8_linear(input=x, weight=weight.t(), ...)
dequant_weight = dequantize_rowwise_fp8_weight(weight, weight_scale, compute_dtype)
return F.linear(x, dequant_weight, bias)
이 변경을 통해 가중치를 미리 dequantize하지 않고, 커널 내부에서 W8A8(Weight 8-bit, Activation 8-bit) 연산을 수행함으로써 메모리 접근 효율을 극대화했습니다.
2. 환경 변수 제어 및 안정성
리뷰어들의 의견에 따라 이 기능은 기본적으로 비활성화되어 있으며, SGLANG_DIFFUSION_ENABLE_W8A8_FP8_GEMM 환경 변수를 통해 명시적으로 활성화해야 합니다. 이는 기존의 weight-only FP8 방식과 출력 결과가 미세하게 다를 수 있기 때문입니다.
왜 이게 좋은가
성능 개선 수치
- Latency: 5218.44 ms -> 3050.28 ms (약 41.5% 감소, 1.71x 속도 향상)
- Memory: Peak reserved memory가 약 822 MB 감소
교훈
- Dequantization 오버헤드 제거: 가중치를 연산 전에 dequantize하는 것은 메모리 대역폭을 낭비하는 주범입니다. Fused 커널을 사용하면 연산과 dequantization을 단일 단계로 처리하여 성능을 비약적으로 높일 수 있습니다.
- 점진적 도입(Opt-in): 수치적 이득이 크더라도 모델의 출력 품질에 영향을 줄 수 있는 최적화는 환경 변수를 통해 선택적으로 적용하는 것이 운영 안정성 측면에서 필수적입니다.
이번 최적화는 특히 H200과 같은 고성능 GPU 환경에서 Diffusion 모델의 추론 처리량을 크게 높여줄 것으로 기대됩니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.nn.functional.linear.html
- https://docs.sglang.ai/developer_guide/contribution_guide.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang의 SM120 FP8 Blockwise GEMM 성능 최적화: Pingpong 스케줄 도입
- [sglang] SGLang LTX-2 VAE 디코딩 성능 최적화: channels_last_3d 도입으로 4.5배 속도 향상
- [sglang] SGLang의 Spectral Progressive Diffusion 도입: 추론 속도 최대 2.78배 향상
- [sglang] SGLang의 Ideogram4 추론 성능 최적화: Denoising 루프 내 오버헤드 제거
- [sglang] [SGLang] LingBot 실시간 서빙 최적화: 카메라 컨디셔닝 캐싱과 전송 프로토콜 개선
PR Analysis 의 다른글
- 이전글 [sglang] ROCm 아키텍처별 최적화: 런타임 디스패치로 성능 극대화
- 현재글 : [sglang] SGLang Diffusion 모델의 FP8 GEMM 최적화: 41.5% 성능 향상 달성
- 다음글 [onnxruntime] WebGPU FlashAttention 최적화: 커널 퓨전과 가변 시퀀스 길이 지원으로 성능 극대화
댓글