[SGLang] FP8: 8비트 부동소수점 양자화의 구현과 성능
들어가며
LLM 추론에서 메모리와 연산량은 모델 크기에 비례하여 증가한다. FP8 양자화는 FP16 대비 가중치 크기를 절반으로 줄이면서도 정확도 손실을 최소화하는 기법이다. SGLang은 python/sglang/srt/layers/quantization/fp8.py에서 FP8 양자화를 구현하며, 텐서별/채널별/블록별 양자화와 동적/정적 활성화 양자화를 모두 지원한다.
구조도
Fp8Config (설정)
├── activation_scheme: "static" | "dynamic"
├── weight_block_size: [block_n, block_k] | None
├── use_mxfp8: bool
└── ignored_layers: List[str]
Fp8Config.get_quant_method()
├── LinearBase → Fp8LinearMethod
├── FusedMoE → Fp8MoEMethod
└── RadixAttention → Fp8KVCacheMethod
핵심 코드 분석
1. Fp8Config 초기화
FP8 설정은 체크포인트 직렬화 여부, 활성화 스킴, 블록 크기를 관리한다.
class Fp8Config(QuantizationConfig):
def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None,
weight_block_size: List[int] = None,
use_mxfp8: bool = False,
) -> None:
self.activation_scheme = activation_scheme
self.weight_block_size = weight_block_size
self.use_mxfp8 = use_mxfp8
if self.use_mxfp8:
if weight_block_size is None:
weight_block_size = [1, 32]
elif weight_block_size != [1, 32]:
raise ValueError("MXFP8 requires weight_block_size=[1, 32].")
MXFP8은 Microscaling FP8로, 블록 크기가 [1, 32]로 고정된다. 일반 FP8 블록 양자화는 [128, 128] 등 다양한 크기를 지원한다.
2. 양자화 메서드 디스패치
get_quant_method는 레이어 타입에 따라 적절한 양자화 메서드를 반환한다.
def get_quant_method(self, layer, prefix):
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return Fp8MoEMethod(self)
elif isinstance(layer, RadixAttention):
return Fp8KVCacheMethod(self)
ignored_layers에 포함된 레이어는 양자화를 건너뛰고 원본 정밀도를 유지한다.
3. Fp8LinearMethod: 가중치 생성
Linear 레이어의 가중치 생성은 블록 양자화와 텐서별 양자화를 구분한다.
class Fp8LinearMethod(LinearMethodBase):
def create_weights(self, layer, input_size_per_partition,
output_partition_sizes, ...):
weight_dtype = (
torch.float8_e4m3fn
if self.is_checkpoint_fp8_serialized else params_dtype
)
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition, input_size_per_partition,
dtype=weight_dtype),
input_dim=1, output_dim=0, weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
체크포인트가 FP8로 직렬화된 경우 torch.float8_e4m3fn dtype을 직접 사용한다. 아닌 경우 FP16/BF16으로 로딩 후 후처리 단계에서 양자화한다.
4. 블록 양자화 스케일 관리
블록 양자화는 가중치를 [block_n, block_k] 크기의 블록으로 나누고, 블록마다 독립 스케일을 할당한다.
if self.block_quant:
scale_dtype = torch.uint8 if self.use_mxfp8 else torch.float32
scale = BlockQuantScaleParameter(
data=scale_init(
(output_size_per_partition + block_n - 1) // block_n,
(input_size_per_partition + block_k - 1) // block_k,
dtype=scale_dtype,
),
input_dim=1, output_dim=0, weight_loader=weight_loader,
)
scale.format_ue8m0 = self.use_mxfp8
layer.register_parameter("weight_scale_inv", scale)
MXFP8 모드에서는 스케일이 UE8M0(unsigned E8M0) 포맷의 uint8로 저장되어 스케일 자체의 메모리도 절약한다.
5. 텐서별 양자화 스케일
블록 양자화가 아닌 경우 텐서별(per-tensor) 스케일을 사용한다.
else:
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", scale)
출력 파티션별로 독립적인 스케일을 관리하여, Tensor Parallel 환경에서도 정확한 양자화를 보장한다.
6. Marlin 커널 자동 선택
FP8 하드웨어가 없는 GPU에서도 Marlin 커널을 통해 FP8 추론을 가속한다.
self.use_marlin = False
if _is_cuda:
force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
auto_enable = can_auto_enable_marlin_fp8()
self.use_marlin = force_marlin or auto_enable
7. MXFP8 가중치 후처리
MXFP8 모드에서는 FlashInfer 백엔드에 따라 가중치 레이아웃을 변환한다.
def _process_mxfp8_linear_weight_scale(self, layer):
if get_fp8_gemm_runner_backend().is_flashinfer_trtllm():
copy_or_rebind_param(layer, "weight",
shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
.view(torch.float8_e4m3fn))
elif get_fp8_gemm_runner_backend().is_flashinfer_cutlass():
copy_or_rebind_param(layer, "weight_scale_inv",
block_scale_interleave(scale_u8.contiguous()).contiguous())
FP16 vs FP8 비교
| 항목 | FP16 | FP8 (E4M3) | FP8 Block (128x128) |
|---|---|---|---|
| 비트 수 | 16 | 8 | 8 |
| 지수/가수 | 5/10 | 4/3 | 4/3 |
| 동적 범위 | ~65504 | ~448 | ~448 (블록별) |
| 가중치 메모리 | 1x | 0.5x | 0.5x + 스케일 오버헤드 |
| 스케일 오버헤드 | 없음 | 텐서당 1개 | 블록당 1개 |
| 정확도 | 기준선 | 약간 감소 | FP8보다 높음 |
| 최소 GPU | 모든 GPU | SM80+ (A100) | SM80+ |
| GEMM 커널 | cuBLAS FP16 | cuBLAS FP8 / CUTLASS | Triton / DeepGEMM |
설계 근거
SGLang FP8 구현의 핵심 설계 결정은 다음과 같다.
- 블록 양자화 우선: 텐서별 양자화보다 블록별 양자화가 정확도 손실이 적다. 블록 크기
[128, 128]은 GEMM 타일 크기와 일치하여 커널 효율을 극대화한다. - 동적 활성화 양자화 기본값: 정적 양자화는 보정(calibration)이 필요하지만, 동적 양자화는 런타임에 스케일을 계산하여 범용성이 높다.
- MXFP8 분리: Microscaling FP8은 블록 크기가
[1, 32]로 고정되어 일반 블록 양자화와 코드 경로를 분리했다. - Marlin 폴백: FP8 네이티브 지원이 없는 GPU(SM80 미만)에서도 Marlin 커널로 weight-only FP8 추론을 지원한다.
관련 포스트
- SGLang FP4: 4비트 부동소수점 양자화
- SGLang Block-wise INT8: 블록 단위 정수 양자화
- SGLang Compressed Tensors: 통합 양자화 프레임워크
참고
- SGLang 소스:
python/sglang/srt/layers/quantization/fp8.py - SGLang FP8 유틸리티:
python/sglang/srt/layers/quantization/fp8_utils.py - FP8 포맷 표준: OCP Microscaling Formats
관련 포스트
- [sglang] sglang, Qwen3.5-397B FP8 모델 성능 벤치마크 추가 및 CI 개선
- [PyTorch] FlexAttention에 저정밀도 K/V 입력 지원 추가
- [flashinfer] FlashInfer FP8 KV-Cache Prefill 성능 최적화: Repacking 기법을 통한 오버헤드 제거
- [vllm] vLLM의 FP8 Scaled MM 최적화: Padding 제거를 통한 20% 성능 향상
- [vllm] [vLLM 분석] DeepSeek V4의 Sparse FP8 Compressor 커널 최적화: CuteDSL을 통한 성능 극대화
SGLang 의 다른글
- 이전글 [SGLang] Warmup: GPU 초기화와 JIT 사전 컴파일
- 현재글 : [SGLang] FP8: 8비트 부동소수점 양자화의 구현과 성능
- 다음글 [SGLang] FP4: 4비트 부동소수점 양자화 (NVIDIA NF4)
댓글