본문으로 건너뛰기

[SGLang] FP4: 4비트 부동소수점 양자화 (NVIDIA NF4)

들어가며

FP8이 메모리를 절반으로 줄였다면, FP4는 한 번 더 절반을 줄인다. 4비트라는 극저정밀도에서 모델 정확도를 유지하기 위해 SGLang은 NVIDIA의 NVFP4 포맷과 다중 백엔드 GEMM 러너를 조합한다. python/sglang/srt/layers/quantization/fp4_utils.py에서 백엔드 선택 로직을, Compressed Tensors의 CompressedTensorsW4A4Fp4 스킴에서 실제 양자화 연산을 구현한다.

구조도

FP4 양자화 흐름
┌─────────────────────────────────────────────┐
│  fp4_utils.py                               │
│  ┌─────────────────────────┐                │
│  │ Fp4GemmRunnerBackend    │                │
│  │  ├── AUTO               │                │
│  │  ├── CUTLASS            │                │
│  │  ├── FLASHINFER_CUDNN   │                │
│  │  ├── FLASHINFER_CUTLASS │                │
│  │  └── FLASHINFER_TRTLLM  │                │
│  └─────────────────────────┘                │
│                                              │
│  initialize_fp4_gemm_config(server_args)    │
│    └── SM120 (Blackwell)? → flashinfer_cudnn│
│    └── else? → flashinfer_cutlass           │
└─────────────────────────────────────────────┘

CompressedTensorsW4A4Fp4 (스킴)
├── weight_packed: [N, K/2] uint8  (2개 FP4를 1바이트에 패킹)
├── weight_scale: [N, K/16] fp8_e4m3fn  (그룹별 스케일)
├── weight_global_scale: float32  (전역 스케일)
└── input_global_scale: float32   (입력 전역 스케일)

핵심 코드 분석

1. GEMM 백엔드 선택 Enum

FP4 GEMM은 하드웨어에 따라 다른 백엔드를 사용한다.

class Fp4GemmRunnerBackend(Enum):
    AUTO = "auto"
    CUTLASS = "cutlass"
    FLASHINFER_CUDNN = "flashinfer_cudnn"
    FLASHINFER_CUTLASS = "flashinfer_cutlass"
    FLASHINFER_TRTLLM = "flashinfer_trtllm"

    def get_flashinfer_backend(self) -> str:
        if self.value.startswith("flashinfer_"):
            return self.value.removeprefix("flashinfer_")
        else:
            return self.value

get_flashinfer_backend()는 SGLang의 백엔드 이름을 FlashInfer API 이름으로 변환한다. 예를 들어 flashinfer_trtllmtrtllm으로 매핑된다.

2. Blackwell 자동 감지

Blackwell(SM120) GPU에서는 cuDNN 백엔드가 안정적이므로 자동 선택된다.

def initialize_fp4_gemm_config(server_args: ServerArgs) -> None:
    global FP4_GEMM_RUNNER_BACKEND
    backend = server_args.fp4_gemm_runner_backend
    if backend == "auto":
        if is_sm120_supported():
            backend = "flashinfer_cudnn"
            logger.info(
                "SM120 (Blackwell) detected: auto-selecting "
                "fp4-gemm-backend=flashinfer_cudnn"
            )
        else:
            backend = "flashinfer_cutlass"
    FP4_GEMM_RUNNER_BACKEND = Fp4GemmRunnerBackend(backend)

이 결정의 배경은 flashinfer_cutlass가 SM120에서 이종 배치(heterogeneous batch)의 Dense MLP 레이어에서 NaN을 발생시키는 버그(#20043) 때문이다.

3. W4A4 가중치 생성 (CompressedTensorsW4A4Fp4)

FP4 가중치는 2개의 4비트 값을 1바이트(uint8)에 패킹하여 저장한다.

class CompressedTensorsW4A4Fp4(CompressedTensorsLinearScheme):
    def __init__(self):
        self.group_size = 16

    def create_weights(self, layer, output_partition_sizes,
                       input_size_per_partition, params_dtype,
                       weight_loader, **kwargs):
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // 2,  # 2 FP4 per byte
                dtype=torch.uint8,
            ),
            input_dim=1, output_dim=0, weight_loader=weight_loader,
        )
        layer.register_parameter("weight_packed", weight)

입력 차원이 절반으로 줄어든 것은 2개의 FP4 값이 하나의 uint8에 패킹되기 때문이다.

4. 다중 스케일 체계

FP4는 정밀도가 매우 낮으므로, 전역 스케일과 그룹별 스케일을 함께 사용하여 정확도를 보존한다.

# 전역 가중치 스케일
weight_global_scale = PerTensorScaleParameter(
    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
    weight_loader=weight_loader,
)
layer.register_parameter("weight_global_scale", weight_global_scale)

# 그룹별 가중치 스케일 (16개 원소마다 1개)
weight_scale = GroupQuantScaleParameter(
    data=torch.empty(
        sum(output_partition_sizes),
        input_size_per_partition // self.group_size,
        dtype=torch.float8_e4m3fn,
    ),
    input_dim=1, output_dim=0, weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)

그룹별 스케일은 torch.float8_e4m3fn 타입으로 저장되어 스케일 자체도 저정밀도를 유지한다.

5. 가중치 후처리: 레이아웃 변환

백엔드에 따라 가중치와 스케일의 메모리 레이아웃을 변환한다.

def process_weights_after_loading(self, layer) -> None:
    layer.alpha = Parameter(
        1 / (layer.input_global_scale * layer.weight_global_scale),
        requires_grad=False,
    )

    if get_fp4_gemm_runner_backend().is_flashinfer_trtllm():
        weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
        weight_scale = shuffle_matrix_sf_a(
            weight_scale.view(torch.uint8), epilogue_tile_m
        ).reshape(weight_scale.shape).view(torch.float8_e4m3fn)
    else:
        swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)

alpha1 / (input_global_scale * weight_global_scale)로 계산되어, GEMM 출력을 원래 스케일로 복원하는 역할을 한다.

6. FP4 추론 경로

추론 시에는 입력을 FP4로 양자화한 뒤 FP4 GEMM 커널을 호출한다.

def apply_weights(self, layer, x, bias=None):
    x_fp4, x_blockscale = fp4_quantize(x, layer.input_global_scale)
    out = fp4_gemm(
        x_fp4, w, x_blockscale, w_blockscale,
        layer.alpha, output_dtype, w_n,
    )
    if bias is not None:
        out = out + bias
    return out.view(*output_shape)

fp4_quantize는 BF16/FP16 입력을 FP4와 블록 스케일로 변환한다. fp4_gemm은 백엔드(CUTLASS/FlashInfer)에 따라 최적화된 커널을 실행한다.

FP8 vs FP4 비교

항목 FP8 (E4M3) FP4 (NVFP4)
비트 수 8 4
가중치 메모리 0.5x (FP16 대비) 0.25x
스케일 구조 텐서/블록별 float32 전역 float32 + 그룹별 fp8
그룹 크기 128 (블록) 16
최소 GPU SM80 (A100) SM100 (Blackwell)
활성화 양자화 FP8 FP4 (W4A4)
패킹 1:1 2개 FP4 → 1 uint8
백엔드 Triton/DeepGEMM CUTLASS/FlashInfer cuDNN

설계 근거

  1. 이중 스케일 체계: 4비트는 표현 범위가 극히 제한적이므로, 전역 스케일로 큰 범위를 맞추고 그룹별 스케일로 세밀한 조정을 한다.
  2. SM120 우선 설계: FP4 GEMM은 Blackwell의 FP4 Tensor Core에 최적화되어 있으며, SM100 미만에서는 지원하지 않는다.
  3. 패킹 최적화: 2개의 FP4를 1바이트에 패킹하여 메모리 대역폭 활용을 극대화한다.
  4. 백엔드 자동 전환: Blackwell에서 CUTLASS가 불안정한 경우를 감지하여 cuDNN으로 자동 폴백한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글