[SGLang] FP4: 4비트 부동소수점 양자화 (NVIDIA NF4)
들어가며
FP8이 메모리를 절반으로 줄였다면, FP4는 한 번 더 절반을 줄인다. 4비트라는 극저정밀도에서 모델 정확도를 유지하기 위해 SGLang은 NVIDIA의 NVFP4 포맷과 다중 백엔드 GEMM 러너를 조합한다. python/sglang/srt/layers/quantization/fp4_utils.py에서 백엔드 선택 로직을, Compressed Tensors의 CompressedTensorsW4A4Fp4 스킴에서 실제 양자화 연산을 구현한다.
구조도
FP4 양자화 흐름
┌─────────────────────────────────────────────┐
│ fp4_utils.py │
│ ┌─────────────────────────┐ │
│ │ Fp4GemmRunnerBackend │ │
│ │ ├── AUTO │ │
│ │ ├── CUTLASS │ │
│ │ ├── FLASHINFER_CUDNN │ │
│ │ ├── FLASHINFER_CUTLASS │ │
│ │ └── FLASHINFER_TRTLLM │ │
│ └─────────────────────────┘ │
│ │
│ initialize_fp4_gemm_config(server_args) │
│ └── SM120 (Blackwell)? → flashinfer_cudnn│
│ └── else? → flashinfer_cutlass │
└─────────────────────────────────────────────┘
CompressedTensorsW4A4Fp4 (스킴)
├── weight_packed: [N, K/2] uint8 (2개 FP4를 1바이트에 패킹)
├── weight_scale: [N, K/16] fp8_e4m3fn (그룹별 스케일)
├── weight_global_scale: float32 (전역 스케일)
└── input_global_scale: float32 (입력 전역 스케일)
핵심 코드 분석
1. GEMM 백엔드 선택 Enum
FP4 GEMM은 하드웨어에 따라 다른 백엔드를 사용한다.
class Fp4GemmRunnerBackend(Enum):
AUTO = "auto"
CUTLASS = "cutlass"
FLASHINFER_CUDNN = "flashinfer_cudnn"
FLASHINFER_CUTLASS = "flashinfer_cutlass"
FLASHINFER_TRTLLM = "flashinfer_trtllm"
def get_flashinfer_backend(self) -> str:
if self.value.startswith("flashinfer_"):
return self.value.removeprefix("flashinfer_")
else:
return self.value
get_flashinfer_backend()는 SGLang의 백엔드 이름을 FlashInfer API 이름으로 변환한다. 예를 들어 flashinfer_trtllm은 trtllm으로 매핑된다.
2. Blackwell 자동 감지
Blackwell(SM120) GPU에서는 cuDNN 백엔드가 안정적이므로 자동 선택된다.
def initialize_fp4_gemm_config(server_args: ServerArgs) -> None:
global FP4_GEMM_RUNNER_BACKEND
backend = server_args.fp4_gemm_runner_backend
if backend == "auto":
if is_sm120_supported():
backend = "flashinfer_cudnn"
logger.info(
"SM120 (Blackwell) detected: auto-selecting "
"fp4-gemm-backend=flashinfer_cudnn"
)
else:
backend = "flashinfer_cutlass"
FP4_GEMM_RUNNER_BACKEND = Fp4GemmRunnerBackend(backend)
이 결정의 배경은 flashinfer_cutlass가 SM120에서 이종 배치(heterogeneous batch)의 Dense MLP 레이어에서 NaN을 발생시키는 버그(#20043) 때문이다.
3. W4A4 가중치 생성 (CompressedTensorsW4A4Fp4)
FP4 가중치는 2개의 4비트 값을 1바이트(uint8)에 패킹하여 저장한다.
class CompressedTensorsW4A4Fp4(CompressedTensorsLinearScheme):
def __init__(self):
self.group_size = 16
def create_weights(self, layer, output_partition_sizes,
input_size_per_partition, params_dtype,
weight_loader, **kwargs):
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 2, # 2 FP4 per byte
dtype=torch.uint8,
),
input_dim=1, output_dim=0, weight_loader=weight_loader,
)
layer.register_parameter("weight_packed", weight)
입력 차원이 절반으로 줄어든 것은 2개의 FP4 값이 하나의 uint8에 패킹되기 때문이다.
4. 다중 스케일 체계
FP4는 정밀도가 매우 낮으므로, 전역 스케일과 그룹별 스케일을 함께 사용하여 정확도를 보존한다.
# 전역 가중치 스케일
weight_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_global_scale", weight_global_scale)
# 그룹별 가중치 스케일 (16개 원소마다 1개)
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // self.group_size,
dtype=torch.float8_e4m3fn,
),
input_dim=1, output_dim=0, weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
그룹별 스케일은 torch.float8_e4m3fn 타입으로 저장되어 스케일 자체도 저정밀도를 유지한다.
5. 가중치 후처리: 레이아웃 변환
백엔드에 따라 가중치와 스케일의 메모리 레이아웃을 변환한다.
def process_weights_after_loading(self, layer) -> None:
layer.alpha = Parameter(
1 / (layer.input_global_scale * layer.weight_global_scale),
requires_grad=False,
)
if get_fp4_gemm_runner_backend().is_flashinfer_trtllm():
weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
weight_scale = shuffle_matrix_sf_a(
weight_scale.view(torch.uint8), epilogue_tile_m
).reshape(weight_scale.shape).view(torch.float8_e4m3fn)
else:
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
alpha는 1 / (input_global_scale * weight_global_scale)로 계산되어, GEMM 출력을 원래 스케일로 복원하는 역할을 한다.
6. FP4 추론 경로
추론 시에는 입력을 FP4로 양자화한 뒤 FP4 GEMM 커널을 호출한다.
def apply_weights(self, layer, x, bias=None):
x_fp4, x_blockscale = fp4_quantize(x, layer.input_global_scale)
out = fp4_gemm(
x_fp4, w, x_blockscale, w_blockscale,
layer.alpha, output_dtype, w_n,
)
if bias is not None:
out = out + bias
return out.view(*output_shape)
fp4_quantize는 BF16/FP16 입력을 FP4와 블록 스케일로 변환한다. fp4_gemm은 백엔드(CUTLASS/FlashInfer)에 따라 최적화된 커널을 실행한다.
FP8 vs FP4 비교
| 항목 | FP8 (E4M3) | FP4 (NVFP4) |
|---|---|---|
| 비트 수 | 8 | 4 |
| 가중치 메모리 | 0.5x (FP16 대비) | 0.25x |
| 스케일 구조 | 텐서/블록별 float32 | 전역 float32 + 그룹별 fp8 |
| 그룹 크기 | 128 (블록) | 16 |
| 최소 GPU | SM80 (A100) | SM100 (Blackwell) |
| 활성화 양자화 | FP8 | FP4 (W4A4) |
| 패킹 | 1:1 | 2개 FP4 → 1 uint8 |
| 백엔드 | Triton/DeepGEMM | CUTLASS/FlashInfer cuDNN |
설계 근거
- 이중 스케일 체계: 4비트는 표현 범위가 극히 제한적이므로, 전역 스케일로 큰 범위를 맞추고 그룹별 스케일로 세밀한 조정을 한다.
- SM120 우선 설계: FP4 GEMM은 Blackwell의 FP4 Tensor Core에 최적화되어 있으며, SM100 미만에서는 지원하지 않는다.
- 패킹 최적화: 2개의 FP4를 1바이트에 패킹하여 메모리 대역폭 활용을 극대화한다.
- 백엔드 자동 전환: Blackwell에서 CUTLASS가 불안정한 경우를 감지하여 cuDNN으로 자동 폴백한다.
관련 포스트
참고
- SGLang 소스:
python/sglang/srt/layers/quantization/fp4_utils.py - CompressedTensors W4A4:
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py - NVIDIA FP4 포맷: NVIDIA Blackwell Architecture
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] FP8: 8비트 부동소수점 양자화의 구현과 성능
- 현재글 : [SGLang] FP4: 4비트 부동소수점 양자화 (NVIDIA NF4)
- 다음글 [SGLang] AWQ: 활성화 인식 가중치 양자화
댓글