본문으로 건너뛰기

[sglang] Cutlass FP8 Blockwise GEMM 최적화: 불필요한 패딩 제거로 GPU 성능 향상

PR 링크: sgl-project/sglang#27896 상태: Merged | 변경: +202 / -3

들어가며

딥러닝 모델의 추론 성능은 GPU 자원을 얼마나 효율적으로 사용하는지에 따라 크게 달라집니다. 특히 대규모 언어 모델(LLM)에서 FP8(8비트 부동소수점) 양자화는 메모리 사용량과 계산 속도 면에서 큰 이점을 제공하지만, 이를 효율적으로 구현하는 것은 복잡한 최적화 작업을 요구합니다. 오늘 분석할 sgl-project/sglang 레포지토리의 PR은 Cutlass FP8 Blockwise GEMM(General Matrix Multiply) 경로에서 발생하는 비효율적인 패딩(padding) 문제를 해결하여 GPU 성능을 개선한 사례입니다.

기존 cutlass_w8a8_block_fp8_linear_with_fallback 함수는 액티베이션(activation)을 양자화한 후 fp8_blockwise_scaled_mm을 호출하는데, 이 과정에서 mat_a (양자화된 액티베이션)와 scales_a (스케일 팩터)를 매번 4의 배수 행(row)으로 패딩하는 비효율적인 작업을 수행했습니다. 이 패딩은 torch.zerostorch.cat 커널을 각각 두 번씩 호출하여 불필요한 GPU 오버헤드를 발생시켰습니다. 특히 M(행 수) 값이 1에서 5 사이로 작은 speculative decoding 시나리오에서는 이 오버헤드가 더욱 두드러졌습니다. 이 PR은 이러한 반복적인 패딩을 제거하여 성능을 최적화합니다.

코드 분석

이 PR의 핵심 아이디어는 패딩을 매 GEMM 호출 직전에 수행하는 대신, 양자화 단계에서 미리 패딩된 버퍼를 할당하는 것입니다. 이를 통해 pad_tensor() 함수가 pad_rows == 0 조건을 만족하여 불필요한 커널 호출을 건너뛰게 됩니다.

python/sglang/srt/layers/quantization/fp8_kernel.py

이 파일에서는 sglang_per_token_group_quant_fp8_row_padded라는 새로운 함수가 추가되었습니다. 이 함수는 기존 sglang_per_token_group_quant_fp8와 유사하게 동작하지만, 양자화된 액티베이션 x_q와 스케일 x_s를 반환하기 전에 미리 4의 배수로 정렬된(row-padded) 버퍼를 할당합니다.

Before:

# 기존 sglang_per_token_group_quant_fp8 함수는 패딩되지 않은 버퍼를 반환합니다.
def sglang_per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # ... (생략) ...
    x_q = torch.empty(x.shape, device=x.device, dtype=fp8_dtype)
    x_s = torch.empty(
        (k // group_size, m) if column_major_scales else (m, k // group_size),
        device=x.device, dtype=torch.float32
    )
    # ... (생략) ...
    return x_q, x_s

After:

--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -565,6 +565,59 @@ def sglang_per_token_group_quant_fp8(
     return x_q, x_s
 
 
+def sglang_per_token_group_quant_fp8_row_padded(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    row_alignment: int = 4,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Per-token-group quant writing into row-padded buffers (col-major scales).
+
+    The cutlass fp8_blockwise_scaled_mm wrapper pads mat_a / scales_a to a
+    multiple of 4 rows on every call (a zeros fill + a cat for each of mat_a
+    and scales_a). Allocating the quant outputs with rows already aligned to
+    ``row_alignment`` makes the wrapper's pad_tensor() short-circuit (pad_rows
+    == 0), removing 2x fill + 2x cat kernels per GEMM. Rows in [m, m_pad) are
+    uninitialized garbage; the caller must slice the GEMM output back to m.
+    """
+    assert x.dim() == 2, "row-padded quant expects a 2D input"
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` must be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    if not (enable_sgl_per_token_group_quant_8bit and group_size in (16, 32, 64, 128)):
+        # No v2 kernel available: keep the legacy unpadded path and let the
+        # GEMM wrapper do the padding.
+        return sglang_per_token_group_quant_fp8(
+            x, group_size, eps, column_major_scales=True
+        )
+
+    m, k = x.shape
+    m_pad = ceil_align(m, row_alignment)
+    # mat_a buffer: (m_pad, k) row-major fp8
+    x_q = torch.empty((m_pad, k), device=x.device, dtype=fp8_dtype)
+    # scales_a buffer: column-major (stride(0) == 1), shape (m_pad, k // group)
+    x_s = torch.empty(
+        (k // group_size, m_pad), device=x.device, dtype=torch.float32
+    ).transpose(0, 1)
+    if m > 0:
+        sgl_per_token_group_quant_8bit(
+            x,
+            x_q[:m],
+            x_s[:m],
+            group_size,
+            eps,
+            fp8_min,
+            fp8_max,
+            False,  # scale_ue8m0
+            False,  # fuse_silu_and_mul
+            None,  # masked_m
+            enable_v2=True,
+        )
+    return x_q, x_s

새로운 sglang_per_token_group_quant_fp8_row_padded 함수는 m_pad = ceil_align(m, row_alignment)를 계산하여 x_qx_sm_pad 크기로 미리 할당합니다. 특히 x_s는 Cutlass 커널이 요구하는 M-major 레이아웃(stride(0) == 1)을 위해 transpose(0, 1)를 수행합니다. 실제 양자화는 x_q[:m]x_s[:m] 슬라이스에 대해서만 이루어지며, [m, m_pad) 범위의 행은 의도적으로 초기화되지 않은 상태로 남겨둡니다. 이는 GEMM이 행 독립적(row-independent)이므로, 이 부분은 결과에 영향을 미치지 않기 때문입니다.

python/sglang/srt/layers/quantization/fp8_utils.py

이 파일에서는 cutlass_w8a8_block_fp8_linear_with_fallback 함수가 수정되어 새로 추가된 sglang_per_token_group_quant_fp8_row_padded를 사용하고, GEMM 결과에서 패딩된 부분을 다시 잘라내는 로직이 추가되었습니다.

Before:

# 기존 cutlass_w8a8_block_fp8_linear_with_fallback 함수는 unpadded quant를 호출합니다.
def cutlass_w8a8_block_fp8_linear_with_fallback(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: Tuple[int, int],
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # ... (생략) ...
    q_input, x_scale = per_token_group_quant_fp8(
        input_2d, block_size[1], column_major_scales=True
    )
    output = fp8_blockwise_scaled_mm(
        q_input, weight.T, x_scale, weight_scale.T, out_dtype=input_2d.dtype
    )
    if bias is not None:
        output += bias
    return output.to(dtype=input_2d.dtype).view(*output_shape)

After:

--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -8,7 +8,10 @@
 import torch
 
 from sglang.srt.layers import deep_gemm_wrapper
-from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+from sglang.srt.layers.quantization.fp8_kernel import (
+    sglang_per_token_group_quant_fp8,
+    sglang_per_token_group_quant_fp8_row_padded,
+)
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
 from sglang.srt.utils.common import torch_release
 
@@ -635,12 +638,19 @@ def cutlass_w8a8_block_fp8_linear_with_fallback(
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
-    q_input, x_scale = per_token_group_quant_fp8(
-        input_2d, block_size[1], column_major_scales=True
+    # Quantize into row-padded buffers so the sgl-kernel wrapper's per-call
+    # pad_tensor() on mat_a / scales_a short-circuits (saves 2x fill + 2x cat
+    # kernels per GEMM). weight_scale.T is left as a K-major view because the
+    # kernel requires scales_b.stride(0) == 1 and materializes it internally.
+    q_input, x_scale = sglang_per_token_group_quant_fp8_row_padded(
+        input_2d, block_size[1]
     )
     output = fp8_blockwise_scaled_mm(
         q_input, weight.T, x_scale, weight_scale.T, out_dtype=input_2d.dtype
     )
+    if output.shape[0] != input_2d.shape[0]:
+        # GEMM ran on the row-padded buffer; drop the padding rows.
+        output = output[: input_2d.shape[0]]
     if bias is not None:
         output += bias
     return output.to(dtype=input_2d.dtype).view(*output_shape)

이제 q_input, x_scale을 얻기 위해 sglang_per_token_group_quant_fp8_row_padded를 호출합니다. GEMM 연산 후에는 output[: input_2d.shape[0]]를 통해 원래 입력의 행 수만큼만 결과를 슬라이싱하여 패딩된 부분을 제거합니다. weight_scale.T는 커널이 scales_b.stride(0) == 1을 요구하고 내부적으로 자체적인 contiguous copy를 만들기 때문에 변경하지 않습니다.

test/registered/quant/test_fp8_blockwise_row_padding.py

이 PR은 새로운 유닛 테스트 파일 test_fp8_blockwise_row_padding.py를 추가하여 변경사항의 정확성을 검증합니다. 주요 테스트 내용은 다음과 같습니다:

  • test_quant_buffers_row_aligned: row_padded 양자화 함수가 4-aligned, M-major 버퍼를 반환하고, 실제 데이터 부분은 기존 양자화 결과와 비트 단위로 일치하는지 확인합니다.
  • test_gemm_bit_exact_vs_legacy: 전체 선형 연산(row-padded 경로)이 기존 unpadded GEMM과 비트 단위로 동일한 결과를 내는지 검증합니다. 다양한 M 값(_M_VALUES = [1, 2, 3, 4, 5, 7, 13, 16, 31, 64, 256])에 대해 테스트하여 모든 시나리오를 커버합니다.
  • test_linear_matches_bf16_reference: FP8 선형 연산 결과가 bf16 참조 행렬 곱셈 결과와 합리적인 오차 범위 내에 있는지 확인하여 수치적 안정성을 검증합니다.

왜 이게 좋은가

이 최적화는 불필요한 GPU 커널 호출을 제거함으로써 성능을 크게 향상시킵니다. 특히 LLM 추론과 같이 반복적인 GEMM 연산이 많은 워크로드에서 그 효과가 두드러집니다.

성능 수치:

H20-3e, TP=4, Qwen3.5-122B-A10B-FP8, NEXTN speculative decoding 환경에서 측정된 결과는 다음과 같습니다.

Kernel family Baseline This PR
cutlass FP8 GEMM 23.4k launches, 454 ms 23.4k launches, 454 ms
group quant 35.0k launches, 107 ms 35.0k launches, 107 ms
fill ~65k launches, ~69 ms 315 launches, 0.4 ms
cat ~63k launches, ~89 ms 364 launches, 0.7 ms

위 표에서 볼 수 있듯이, fillcat 커널의 호출 횟수와 GPU 시간이 거의 완전히 제거되었습니다. 이는 약 158ms의 GPU 시간을 절약하는 효과를 가져왔습니다.

End-to-end 성능:

Workload Baseline This PR Delta
bs=1 single request, latency 2.21 s 2.21 s no change
concurrency 64, output throughput 2617 tok/s 2671 tok/s +2.1%

단일 요청(bs=1)의 Latency는 CPU-bound 특성으로 인해 변화가 없었지만, GPU-bound 환경인 concurrency 64에서는 +2.1%의 처리량(throughput) 향상을 달성했습니다. 이는 GPU가 최대로 활용되는 상황에서 불필요한 커널 오버헤드를 줄이는 것이 실제 성능 개선으로 이어진다는 것을 보여줍니다.

일반적 교훈:

  1. 하드웨어/라이브러리 요구사항의 조기 처리: Cutlass SM90 블록와이즈 커널이 A operand의 행 수가 4의 배수여야 한다는 요구사항을 매번 동적으로 처리하는 대신, 양자화 단계에서 미리 버퍼를 정렬함으로써 불필요한 작업을 제거했습니다. 이는 하드웨어 또는 라이브러리 제약을 미리 파악하고 초기 단계에서 해결하는 것이 중요함을 시사합니다.
  2. 메모리 할당 전략 최적화: torch.zerostorch.cat와 같은 메모리 조작 커널은 비용이 많이 들 수 있습니다. 필요한 메모리를 한 번에 올바른 크기와 레이아웃으로 할당하는 것은 반복적인 메모리 복사 및 재할당을 피하는 효과적인 방법입니다.
  3. 작은 M 값에서의 오버헤드 민감성: speculative decoding과 같이 M 값이 작은 경우, 매 GEMM 호출마다 발생하는 작은 오버헤드라도 전체 성능에 큰 영향을 미칠 수 있습니다. 이러한 시나리오에서는 미세한 최적화가 중요합니다.
  4. CPU-bound vs GPU-bound 분석: 최적화의 효과는 시스템의 병목 지점에 따라 다르게 나타납니다. 이 PR의 경우, CPU-bound 워크로드에서는 효과가 미미했지만, GPU-bound 워크로드에서는 명확한 성능 향상을 보였습니다. 이는 성능 분석 시 시스템의 병목 지점을 정확히 파악하는 것이 중요함을 알려줍니다.

리뷰 피드백 반영

PR 리뷰 과정에서 python/sglang/srt/layers/quantization/fp8_kernel.py 파일에 대한 yuan-luo님의 "Good catch, fixed."라는 코멘트가 있었습니다. PR 설명에 따르면 scales_a 버퍼는 커널이 요구하는 M-major 레이아웃(stride(0) == 1)을 가져야 합니다. 초기 구현에서 이 부분이 제대로 처리되지 않았을 가능성이 있으며, 리뷰를 통해 수정되어 transpose(0, 1) 연산을 통해 올바른 레이아웃을 갖도록 보완되었을 것으로 추정됩니다. 이는 코드의 정확성과 커널 요구사항 준수를 위한 중요한 피드백이었습니다.

이 PR은 SGLang의 FP8 추론 경로에서 중요한 성능 개선을 이루었으며, 저수준 GPU 프로그래밍과 최적화의 중요성을 잘 보여주는 사례입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글