본문으로 건너뛰기

[sglang] SGLang Diffusion 모델의 FP8 GEMM 최적화: 41.5% 성능 향상 달성

PR 링크: sgl-project/sglang#27590 상태: Merged | 변경: +153 / -15

들어가며

최근 대규모 생성형 모델, 특히 Diffusion 모델을 서빙할 때 연산 효율성은 매우 중요한 요소입니다. 기존 SGLang의 WeightOnlyFP8Linear 구현은 매 forward pass마다 FP8 가중치를 compute dtype(FP16/BF16)으로 dequantize한 뒤 F.linear를 호출하는 방식을 사용했습니다. 이 과정은 메모리 대역폭과 연산 오버헤드를 유발하여 전체 추론 속도를 저하시키는 병목이 되었습니다. 본 PR은 이 과정을 최적화하기 위해 기존의 Fused FP8 GEMM 커널을 활용하도록 개선했습니다.

코드 분석

1. weight_only_fp8.py의 핵심 로직 변경

가장 중요한 변경은 _apply_weight_only_fp8_linear 함수입니다. 기존에는 무조건 dequantize를 수행했으나, 이제는 환경 변수를 통해 Fused GEMM 사용 여부를 결정하고, 가능한 경우 커널을 직접 호출합니다.

# Before
def forward(self, x: torch.Tensor) -> torch.Tensor:
    compute_dtype = self.compute_dtype or x.dtype
    weight = dequantize_rowwise_fp8_weight(self.weight, self.weight_scale, compute_dtype)
    return F.linear(x.to(compute_dtype), weight, bias)

# After
def _apply_weight_only_fp8_linear(...):
    if enable_fused_w8a8 and _can_apply_fused_w8a8_fp8_linear(...):
        return _apply_srt_w8a8_fp8_linear(input=x, weight=weight.t(), ...)
    dequant_weight = dequantize_rowwise_fp8_weight(weight, weight_scale, compute_dtype)
    return F.linear(x, dequant_weight, bias)

이 변경을 통해 가중치를 미리 dequantize하지 않고, 커널 내부에서 W8A8(Weight 8-bit, Activation 8-bit) 연산을 수행함으로써 메모리 접근 효율을 극대화했습니다.

2. 환경 변수 제어 및 안정성

리뷰어들의 의견에 따라 이 기능은 기본적으로 비활성화되어 있으며, SGLANG_DIFFUSION_ENABLE_W8A8_FP8_GEMM 환경 변수를 통해 명시적으로 활성화해야 합니다. 이는 기존의 weight-only FP8 방식과 출력 결과가 미세하게 다를 수 있기 때문입니다.

왜 이게 좋은가

성능 개선 수치

  • Latency: 5218.44 ms -> 3050.28 ms (약 41.5% 감소, 1.71x 속도 향상)
  • Memory: Peak reserved memory가 약 822 MB 감소

교훈

  1. Dequantization 오버헤드 제거: 가중치를 연산 전에 dequantize하는 것은 메모리 대역폭을 낭비하는 주범입니다. Fused 커널을 사용하면 연산과 dequantization을 단일 단계로 처리하여 성능을 비약적으로 높일 수 있습니다.
  2. 점진적 도입(Opt-in): 수치적 이득이 크더라도 모델의 출력 품질에 영향을 줄 수 있는 최적화는 환경 변수를 통해 선택적으로 적용하는 것이 운영 안정성 측면에서 필수적입니다.

이번 최적화는 특히 H200과 같은 고성능 GPU 환경에서 Diffusion 모델의 추론 처리량을 크게 높여줄 것으로 기대됩니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글