[sglang] Cutlass FP8 Blockwise GEMM 최적화: 불필요한 패딩 제거로 GPU 성능 향상
PR 링크: sgl-project/sglang#27896 상태: Merged | 변경: +202 / -3
들어가며
딥러닝 모델의 추론 성능은 GPU 자원을 얼마나 효율적으로 사용하는지에 따라 크게 달라집니다. 특히 대규모 언어 모델(LLM)에서 FP8(8비트 부동소수점) 양자화는 메모리 사용량과 계산 속도 면에서 큰 이점을 제공하지만, 이를 효율적으로 구현하는 것은 복잡한 최적화 작업을 요구합니다. 오늘 분석할 sgl-project/sglang 레포지토리의 PR은 Cutlass FP8 Blockwise GEMM(General Matrix Multiply) 경로에서 발생하는 비효율적인 패딩(padding) 문제를 해결하여 GPU 성능을 개선한 사례입니다.
기존 cutlass_w8a8_block_fp8_linear_with_fallback 함수는 액티베이션(activation)을 양자화한 후 fp8_blockwise_scaled_mm을 호출하는데, 이 과정에서 mat_a (양자화된 액티베이션)와 scales_a (스케일 팩터)를 매번 4의 배수 행(row)으로 패딩하는 비효율적인 작업을 수행했습니다. 이 패딩은 torch.zeros와 torch.cat 커널을 각각 두 번씩 호출하여 불필요한 GPU 오버헤드를 발생시켰습니다. 특히 M(행 수) 값이 1에서 5 사이로 작은 speculative decoding 시나리오에서는 이 오버헤드가 더욱 두드러졌습니다. 이 PR은 이러한 반복적인 패딩을 제거하여 성능을 최적화합니다.
코드 분석
이 PR의 핵심 아이디어는 패딩을 매 GEMM 호출 직전에 수행하는 대신, 양자화 단계에서 미리 패딩된 버퍼를 할당하는 것입니다. 이를 통해 pad_tensor() 함수가 pad_rows == 0 조건을 만족하여 불필요한 커널 호출을 건너뛰게 됩니다.
python/sglang/srt/layers/quantization/fp8_kernel.py
이 파일에서는 sglang_per_token_group_quant_fp8_row_padded라는 새로운 함수가 추가되었습니다. 이 함수는 기존 sglang_per_token_group_quant_fp8와 유사하게 동작하지만, 양자화된 액티베이션 x_q와 스케일 x_s를 반환하기 전에 미리 4의 배수로 정렬된(row-padded) 버퍼를 할당합니다.
Before:
# 기존 sglang_per_token_group_quant_fp8 함수는 패딩되지 않은 버퍼를 반환합니다.
def sglang_per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
# ... (생략) ...
x_q = torch.empty(x.shape, device=x.device, dtype=fp8_dtype)
x_s = torch.empty(
(k // group_size, m) if column_major_scales else (m, k // group_size),
device=x.device, dtype=torch.float32
)
# ... (생략) ...
return x_q, x_s
After:
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -565,6 +565,59 @@ def sglang_per_token_group_quant_fp8(
return x_q, x_s
+def sglang_per_token_group_quant_fp8_row_padded(
+ x: torch.Tensor,
+ group_size: int,
+ eps: float = 1e-10,
+ row_alignment: int = 4,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Per-token-group quant writing into row-padded buffers (col-major scales).
+
+ The cutlass fp8_blockwise_scaled_mm wrapper pads mat_a / scales_a to a
+ multiple of 4 rows on every call (a zeros fill + a cat for each of mat_a
+ and scales_a). Allocating the quant outputs with rows already aligned to
+ ``row_alignment`` makes the wrapper's pad_tensor() short-circuit (pad_rows
+ == 0), removing 2x fill + 2x cat kernels per GEMM. Rows in [m, m_pad) are
+ uninitialized garbage; the caller must slice the GEMM output back to m.
+ """
+ assert x.dim() == 2, "row-padded quant expects a 2D input"
+ assert (
+ x.shape[-1] % group_size == 0
+ ), "the last dimension of `x` must be divisible by `group_size`"
+ assert x.is_contiguous(), "`x` is not contiguous"
+
+ if not (enable_sgl_per_token_group_quant_8bit and group_size in (16, 32, 64, 128)):
+ # No v2 kernel available: keep the legacy unpadded path and let the
+ # GEMM wrapper do the padding.
+ return sglang_per_token_group_quant_fp8(
+ x, group_size, eps, column_major_scales=True
+ )
+
+ m, k = x.shape
+ m_pad = ceil_align(m, row_alignment)
+ # mat_a buffer: (m_pad, k) row-major fp8
+ x_q = torch.empty((m_pad, k), device=x.device, dtype=fp8_dtype)
+ # scales_a buffer: column-major (stride(0) == 1), shape (m_pad, k // group)
+ x_s = torch.empty(
+ (k // group_size, m_pad), device=x.device, dtype=torch.float32
+ ).transpose(0, 1)
+ if m > 0:
+ sgl_per_token_group_quant_8bit(
+ x,
+ x_q[:m],
+ x_s[:m],
+ group_size,
+ eps,
+ fp8_min,
+ fp8_max,
+ False, # scale_ue8m0
+ False, # fuse_silu_and_mul
+ None, # masked_m
+ enable_v2=True,
+ )
+ return x_q, x_s
새로운 sglang_per_token_group_quant_fp8_row_padded 함수는 m_pad = ceil_align(m, row_alignment)를 계산하여 x_q와 x_s를 m_pad 크기로 미리 할당합니다. 특히 x_s는 Cutlass 커널이 요구하는 M-major 레이아웃(stride(0) == 1)을 위해 transpose(0, 1)를 수행합니다. 실제 양자화는 x_q[:m]과 x_s[:m] 슬라이스에 대해서만 이루어지며, [m, m_pad) 범위의 행은 의도적으로 초기화되지 않은 상태로 남겨둡니다. 이는 GEMM이 행 독립적(row-independent)이므로, 이 부분은 결과에 영향을 미치지 않기 때문입니다.
python/sglang/srt/layers/quantization/fp8_utils.py
이 파일에서는 cutlass_w8a8_block_fp8_linear_with_fallback 함수가 수정되어 새로 추가된 sglang_per_token_group_quant_fp8_row_padded를 사용하고, GEMM 결과에서 패딩된 부분을 다시 잘라내는 로직이 추가되었습니다.
Before:
# 기존 cutlass_w8a8_block_fp8_linear_with_fallback 함수는 unpadded quant를 호출합니다.
def cutlass_w8a8_block_fp8_linear_with_fallback(
input: torch.Tensor,
weight: torch.Tensor,
block_size: Tuple[int, int],
weight_scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# ... (생략) ...
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=True
)
output = fp8_blockwise_scaled_mm(
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input_2d.dtype
)
if bias is not None:
output += bias
return output.to(dtype=input_2d.dtype).view(*output_shape)
After:
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -8,7 +8,10 @@
import torch
from sglang.srt.layers import deep_gemm_wrapper
-from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+from sglang.srt.layers.quantization.fp8_kernel import (
+ sglang_per_token_group_quant_fp8,
+ sglang_per_token_group_quant_fp8_row_padded,
+)
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
from sglang.srt.utils.common import torch_release
@@ -635,12 +638,19 @@ def cutlass_w8a8_block_fp8_linear_with_fallback(
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
- q_input, x_scale = per_token_group_quant_fp8(
- input_2d, block_size[1], column_major_scales=True
+ # Quantize into row-padded buffers so the sgl-kernel wrapper's per-call
+ # pad_tensor() on mat_a / scales_a short-circuits (saves 2x fill + 2x cat
+ # kernels per GEMM). weight_scale.T is left as a K-major view because the
+ # kernel requires scales_b.stride(0) == 1 and materializes it internally.
+ q_input, x_scale = sglang_per_token_group_quant_fp8_row_padded(
+ input_2d, block_size[1]
)
output = fp8_blockwise_scaled_mm(
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input_2d.dtype
)
+ if output.shape[0] != input_2d.shape[0]:
+ # GEMM ran on the row-padded buffer; drop the padding rows.
+ output = output[: input_2d.shape[0]]
if bias is not None:
output += bias
return output.to(dtype=input_2d.dtype).view(*output_shape)
이제 q_input, x_scale을 얻기 위해 sglang_per_token_group_quant_fp8_row_padded를 호출합니다. GEMM 연산 후에는 output[: input_2d.shape[0]]를 통해 원래 입력의 행 수만큼만 결과를 슬라이싱하여 패딩된 부분을 제거합니다. weight_scale.T는 커널이 scales_b.stride(0) == 1을 요구하고 내부적으로 자체적인 contiguous copy를 만들기 때문에 변경하지 않습니다.
test/registered/quant/test_fp8_blockwise_row_padding.py
이 PR은 새로운 유닛 테스트 파일 test_fp8_blockwise_row_padding.py를 추가하여 변경사항의 정확성을 검증합니다. 주요 테스트 내용은 다음과 같습니다:
test_quant_buffers_row_aligned:row_padded양자화 함수가 4-aligned, M-major 버퍼를 반환하고, 실제 데이터 부분은 기존 양자화 결과와 비트 단위로 일치하는지 확인합니다.test_gemm_bit_exact_vs_legacy: 전체 선형 연산(row-padded 경로)이 기존 unpadded GEMM과 비트 단위로 동일한 결과를 내는지 검증합니다. 다양한 M 값(_M_VALUES = [1, 2, 3, 4, 5, 7, 13, 16, 31, 64, 256])에 대해 테스트하여 모든 시나리오를 커버합니다.test_linear_matches_bf16_reference: FP8 선형 연산 결과가 bf16 참조 행렬 곱셈 결과와 합리적인 오차 범위 내에 있는지 확인하여 수치적 안정성을 검증합니다.
왜 이게 좋은가
이 최적화는 불필요한 GPU 커널 호출을 제거함으로써 성능을 크게 향상시킵니다. 특히 LLM 추론과 같이 반복적인 GEMM 연산이 많은 워크로드에서 그 효과가 두드러집니다.
성능 수치:
H20-3e, TP=4, Qwen3.5-122B-A10B-FP8, NEXTN speculative decoding 환경에서 측정된 결과는 다음과 같습니다.
| Kernel family | Baseline | This PR |
|---|---|---|
| cutlass FP8 GEMM | 23.4k launches, 454 ms | 23.4k launches, 454 ms |
| group quant | 35.0k launches, 107 ms | 35.0k launches, 107 ms |
| fill | ~65k launches, ~69 ms | 315 launches, 0.4 ms |
| cat | ~63k launches, ~89 ms | 364 launches, 0.7 ms |
위 표에서 볼 수 있듯이, fill과 cat 커널의 호출 횟수와 GPU 시간이 거의 완전히 제거되었습니다. 이는 약 158ms의 GPU 시간을 절약하는 효과를 가져왔습니다.
End-to-end 성능:
| Workload | Baseline | This PR | Delta |
|---|---|---|---|
| bs=1 single request, latency | 2.21 s | 2.21 s | no change |
| concurrency 64, output throughput | 2617 tok/s | 2671 tok/s | +2.1% |
단일 요청(bs=1)의 Latency는 CPU-bound 특성으로 인해 변화가 없었지만, GPU-bound 환경인 concurrency 64에서는 +2.1%의 처리량(throughput) 향상을 달성했습니다. 이는 GPU가 최대로 활용되는 상황에서 불필요한 커널 오버헤드를 줄이는 것이 실제 성능 개선으로 이어진다는 것을 보여줍니다.
일반적 교훈:
- 하드웨어/라이브러리 요구사항의 조기 처리: Cutlass SM90 블록와이즈 커널이 A operand의 행 수가 4의 배수여야 한다는 요구사항을 매번 동적으로 처리하는 대신, 양자화 단계에서 미리 버퍼를 정렬함으로써 불필요한 작업을 제거했습니다. 이는 하드웨어 또는 라이브러리 제약을 미리 파악하고 초기 단계에서 해결하는 것이 중요함을 시사합니다.
- 메모리 할당 전략 최적화:
torch.zeros와torch.cat와 같은 메모리 조작 커널은 비용이 많이 들 수 있습니다. 필요한 메모리를 한 번에 올바른 크기와 레이아웃으로 할당하는 것은 반복적인 메모리 복사 및 재할당을 피하는 효과적인 방법입니다. - 작은 M 값에서의 오버헤드 민감성: speculative decoding과 같이 M 값이 작은 경우, 매 GEMM 호출마다 발생하는 작은 오버헤드라도 전체 성능에 큰 영향을 미칠 수 있습니다. 이러한 시나리오에서는 미세한 최적화가 중요합니다.
- CPU-bound vs GPU-bound 분석: 최적화의 효과는 시스템의 병목 지점에 따라 다르게 나타납니다. 이 PR의 경우, CPU-bound 워크로드에서는 효과가 미미했지만, GPU-bound 워크로드에서는 명확한 성능 향상을 보였습니다. 이는 성능 분석 시 시스템의 병목 지점을 정확히 파악하는 것이 중요함을 알려줍니다.
리뷰 피드백 반영
PR 리뷰 과정에서 python/sglang/srt/layers/quantization/fp8_kernel.py 파일에 대한 yuan-luo님의 "Good catch, fixed."라는 코멘트가 있었습니다. PR 설명에 따르면 scales_a 버퍼는 커널이 요구하는 M-major 레이아웃(stride(0) == 1)을 가져야 합니다. 초기 구현에서 이 부분이 제대로 처리되지 않았을 가능성이 있으며, 리뷰를 통해 수정되어 transpose(0, 1) 연산을 통해 올바른 레이아웃을 갖도록 보완되었을 것으로 추정됩니다. 이는 코드의 정확성과 커널 요구사항 준수를 위한 중요한 피드백이었습니다.
이 PR은 SGLang의 FP8 추론 경로에서 중요한 성능 개선을 이루었으며, 저수준 GPU 프로그래밍과 최적화의 중요성을 잘 보여주는 사례입니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.cat.html
- https://pytorch.org/docs/stable/generated/torch.zeros.html
- https://pytorch.org/docs/stable/generated/torch.empty.html
- https://pytorch.org/docs/stable/generated/torch.transpose.html
- https://pytorch.org/docs/stable/generated/torch.Tensor.slice.html
- https://github.com/NVIDIA/cutlass
- https://pytorch.org/blog/pytorch-fp8/
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [vllm] vLLM에서 Lfm2VL 모델을 위한 Encoder CUDA Graph 최적화 적용
- 현재글 : [sglang] Cutlass FP8 Blockwise GEMM 최적화: 불필요한 패딩 제거로 GPU 성능 향상
- 다음글 [triton] Triton AMD StreamK GEMM 커널의 Race Condition 해결: 동기화 로직 최적화 분석
댓글