본문으로 건너뛰기

[SGLang] FP8: 8비트 부동소수점 양자화의 구현과 성능

들어가며

LLM 추론에서 메모리와 연산량은 모델 크기에 비례하여 증가한다. FP8 양자화는 FP16 대비 가중치 크기를 절반으로 줄이면서도 정확도 손실을 최소화하는 기법이다. SGLang은 python/sglang/srt/layers/quantization/fp8.py에서 FP8 양자화를 구현하며, 텐서별/채널별/블록별 양자화와 동적/정적 활성화 양자화를 모두 지원한다.

구조도

Fp8Config (설정)
├── activation_scheme: "static" | "dynamic"
├── weight_block_size: [block_n, block_k] | None
├── use_mxfp8: bool
└── ignored_layers: List[str]

Fp8Config.get_quant_method()
├── LinearBase  → Fp8LinearMethod
├── FusedMoE    → Fp8MoEMethod
└── RadixAttention → Fp8KVCacheMethod

핵심 코드 분석

1. Fp8Config 초기화

FP8 설정은 체크포인트 직렬화 여부, 활성화 스킴, 블록 크기를 관리한다.

class Fp8Config(QuantizationConfig):
    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: List[int] = None,
        use_mxfp8: bool = False,
    ) -> None:
        self.activation_scheme = activation_scheme
        self.weight_block_size = weight_block_size
        self.use_mxfp8 = use_mxfp8
        if self.use_mxfp8:
            if weight_block_size is None:
                weight_block_size = [1, 32]
            elif weight_block_size != [1, 32]:
                raise ValueError("MXFP8 requires weight_block_size=[1, 32].")

MXFP8은 Microscaling FP8로, 블록 크기가 [1, 32]로 고정된다. 일반 FP8 블록 양자화는 [128, 128] 등 다양한 크기를 지원한다.

2. 양자화 메서드 디스패치

get_quant_method는 레이어 타입에 따라 적절한 양자화 메서드를 반환한다.

def get_quant_method(self, layer, prefix):
    if isinstance(layer, LinearBase):
        if is_layer_skipped(prefix, self.ignored_layers,
                           fused_mapping=self.packed_modules_mapping):
            return UnquantizedLinearMethod()
        return Fp8LinearMethod(self)
    elif isinstance(layer, FusedMoE):
        return Fp8MoEMethod(self)
    elif isinstance(layer, RadixAttention):
        return Fp8KVCacheMethod(self)

ignored_layers에 포함된 레이어는 양자화를 건너뛰고 원본 정밀도를 유지한다.

3. Fp8LinearMethod: 가중치 생성

Linear 레이어의 가중치 생성은 블록 양자화와 텐서별 양자화를 구분한다.

class Fp8LinearMethod(LinearMethodBase):
    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, ...):
        weight_dtype = (
            torch.float8_e4m3fn
            if self.is_checkpoint_fp8_serialized else params_dtype
        )
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition,
                dtype=weight_dtype),
            input_dim=1, output_dim=0, weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

체크포인트가 FP8로 직렬화된 경우 torch.float8_e4m3fn dtype을 직접 사용한다. 아닌 경우 FP16/BF16으로 로딩 후 후처리 단계에서 양자화한다.

4. 블록 양자화 스케일 관리

블록 양자화는 가중치를 [block_n, block_k] 크기의 블록으로 나누고, 블록마다 독립 스케일을 할당한다.

if self.block_quant:
    scale_dtype = torch.uint8 if self.use_mxfp8 else torch.float32
    scale = BlockQuantScaleParameter(
        data=scale_init(
            (output_size_per_partition + block_n - 1) // block_n,
            (input_size_per_partition + block_k - 1) // block_k,
            dtype=scale_dtype,
        ),
        input_dim=1, output_dim=0, weight_loader=weight_loader,
    )
    scale.format_ue8m0 = self.use_mxfp8
    layer.register_parameter("weight_scale_inv", scale)

MXFP8 모드에서는 스케일이 UE8M0(unsigned E8M0) 포맷의 uint8로 저장되어 스케일 자체의 메모리도 절약한다.

5. 텐서별 양자화 스케일

블록 양자화가 아닌 경우 텐서별(per-tensor) 스케일을 사용한다.

else:
    scale = PerTensorScaleParameter(
        data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
        weight_loader=weight_loader,
    )
    scale[:] = torch.finfo(torch.float32).min
    layer.register_parameter("weight_scale", scale)

출력 파티션별로 독립적인 스케일을 관리하여, Tensor Parallel 환경에서도 정확한 양자화를 보장한다.

6. Marlin 커널 자동 선택

FP8 하드웨어가 없는 GPU에서도 Marlin 커널을 통해 FP8 추론을 가속한다.

self.use_marlin = False
if _is_cuda:
    force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
    auto_enable = can_auto_enable_marlin_fp8()
    self.use_marlin = force_marlin or auto_enable

7. MXFP8 가중치 후처리

MXFP8 모드에서는 FlashInfer 백엔드에 따라 가중치 레이아웃을 변환한다.

def _process_mxfp8_linear_weight_scale(self, layer):
    if get_fp8_gemm_runner_backend().is_flashinfer_trtllm():
        copy_or_rebind_param(layer, "weight",
            shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            .view(torch.float8_e4m3fn))
    elif get_fp8_gemm_runner_backend().is_flashinfer_cutlass():
        copy_or_rebind_param(layer, "weight_scale_inv",
            block_scale_interleave(scale_u8.contiguous()).contiguous())

FP16 vs FP8 비교

항목 FP16 FP8 (E4M3) FP8 Block (128x128)
비트 수 16 8 8
지수/가수 5/10 4/3 4/3
동적 범위 ~65504 ~448 ~448 (블록별)
가중치 메모리 1x 0.5x 0.5x + 스케일 오버헤드
스케일 오버헤드 없음 텐서당 1개 블록당 1개
정확도 기준선 약간 감소 FP8보다 높음
최소 GPU 모든 GPU SM80+ (A100) SM80+
GEMM 커널 cuBLAS FP16 cuBLAS FP8 / CUTLASS Triton / DeepGEMM

설계 근거

SGLang FP8 구현의 핵심 설계 결정은 다음과 같다.

  1. 블록 양자화 우선: 텐서별 양자화보다 블록별 양자화가 정확도 손실이 적다. 블록 크기 [128, 128]은 GEMM 타일 크기와 일치하여 커널 효율을 극대화한다.
  2. 동적 활성화 양자화 기본값: 정적 양자화는 보정(calibration)이 필요하지만, 동적 양자화는 런타임에 스케일을 계산하여 범용성이 높다.
  3. MXFP8 분리: Microscaling FP8은 블록 크기가 [1, 32]로 고정되어 일반 블록 양자화와 코드 경로를 분리했다.
  4. Marlin 폴백: FP8 네이티브 지원이 없는 GPU(SM80 미만)에서도 Marlin 커널로 weight-only FP8 추론을 지원한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글