본문으로 건너뛰기

[SGLang] W4A8, W8A8, W4A4: 혼합 정밀도 양자화 스킴

들어가며

양자화의 핵심 설계 결정은 가중치와 활성화에 각각 몇 비트를 할당할 것인가이다. SGLang의 Compressed Tensors 프레임워크는 schemes/ 디렉터리에서 W8A8 FP8, W8A8 INT8, W4A4 NVFP4, W8A16 FP8, WNA16 등 다양한 스킴을 구현한다. 각 스킴은 독립적인 가중치 생성, 후처리, 추론 경로를 갖는다.

구조도

compressed_tensors/schemes/
├── compressed_tensors_scheme.py          # 추상 기반 클래스
├── compressed_tensors_w8a8_fp8.py        # W8A8 FP8
├── compressed_tensors_w8a8_int8.py       # W8A8 INT8
├── compressed_tensors_w4a4_nvfp4.py      # W4A4 NVFP4
├── compressed_tensors_w8a16_fp8.py       # W8A16 FP8 (weight-only)
├── compressed_tensors_wNa16.py           # WNA16 (Marlin, N=4/8)
├── compressed_tensors_w8a8_fp8_moe.py    # MoE FP8
├── compressed_tensors_w8a8_int8_moe.py   # MoE INT8
├── compressed_tensors_w4a4_nvfp4_moe.py  # MoE NVFP4
├── compressed_tensors_w4a4_mxint4_moe.py # MoE MxINT4
├── compressed_tensors_w4a8_int8_moe.py   # MoE W4A8
└── compressed_tensors_wNa16_moe.py       # MoE WNA16

핵심 코드 분석

1. W8A8 FP8: 채널별/텐서별/블록별 전략

W8A8 FP8은 가장 유연한 스킴으로, 세 가지 양자화 전략을 지원한다.

class CompressedTensorsW8A8Fp8(CompressedTensorsLinearScheme):
    def __init__(self, weight_quant, is_static_input_scheme):
        self.strategy = self.weight_quant.strategy
        self.is_static_input_scheme = is_static_input_scheme
        self.weight_block_size = self.weight_quant.block_structure

전략에 따라 스케일 파라미터가 달라진다:

# 채널별: [N, 1] 스케일
if self.strategy == QuantizationStrategy.CHANNEL:
    weight_scale = ChannelQuantScaleParameter(
        data=torch.empty((sum(output_partition_sizes), 1),
                         dtype=torch.float32),
        output_dim=0, weight_loader=weight_loader,
    )
# 텐서별: 파티션당 1개 스케일
elif self.strategy == QuantizationStrategy.TENSOR:
    weight_scale = PerTensorScaleParameter(
        data=torch.empty(len(output_partition_sizes),
                         dtype=torch.float32),
        weight_loader=weight_loader,
    )
# 블록별: [N/bn, K/bk] 스케일
elif self.strategy == QuantizationStrategy.BLOCK:
    weight_scale = BlockQuantScaleParameter(...)

2. W8A8 FP8 추론 경로

추론 시에는 전략에 따라 다른 커널을 호출한다.

def apply_weights(self, layer, x, bias=None):
    if self.weight_block_size is not None:
        return self.w8a8_block_fp8_linear(
            input=x, weight=layer.weight,
            block_size=self.weight_block_size,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale, bias=bias,
        )
    if _use_aiter and self.strategy == QuantizationStrategy.CHANNEL:
        return apply_fp8_ptpc_linear(...)  # AMD AIter 최적화
    else:
        return apply_fp8_linear(...)       # 범용 경로

AMD GPU에서는 apply_fp8_ptpc_linear(per-token per-channel)로 최적화된 커널을 사용한다.

3. W8A8 INT8: 정수 양자화

INT8은 정수 연산을 활용하여 추론 속도를 높인다.

class CompressedTensorsW8A8Int8(CompressedTensorsLinearScheme):
    def __init__(self, strategy, is_static_input_scheme, input_symmetric):
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric

INT8의 독특한 점은 비대칭 입력 양자화를 지원한다는 것이다:

def apply_weights(self, layer, x, bias=None):
    x_q, x_scale = per_token_quant_int8(x)
    return int8_scaled_mm(
        x_q, layer.weight, x_scale, layer.weight_scale,
        out_dtype=x.dtype, bias=bias
    )

per_token_quant_int8은 토큰별로 INT8 양자화를 수행하고, int8_scaled_mmsgl_kernel의 최적화된 INT8 행렬곱을 실행한다.

4. W8A8 INT8 비대칭 양자화: AZP 보정

비대칭 양자화 시에는 Zero Point(AZP) 보정이 필요하다.

if not self.input_symmetric:
    weight = layer.weight
    azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
    if self.is_static_input_scheme:
        azp_adj = layer.input_zero_point * azp_adj
    layer.azp_adj = Parameter(azp_adj, requires_grad=False)

azp_adj는 가중치의 열 합으로 계산되며, 양자화 오프셋을 보정한다.

5. W4A4 NVFP4: 극저정밀도

W4A4는 가중치와 활성화 모두 4비트를 사용한다.

class CompressedTensorsW4A4Fp4(CompressedTensorsLinearScheme):
    def __init__(self):
        self.group_size = 16

    @classmethod
    def get_min_capability(cls) -> int:
        return 100  # Hopper 이상

SM100(Hopper) 이상에서만 동작하며, 2개의 FP4가 1바이트로 패킹된다. 이중 스케일(전역 + 그룹별)로 정확도를 보존한다.

6. WNA16: Marlin 기반 Weight-Only 양자화

WNA16은 가중치만 양자화하고 활성화는 16비트를 유지한다.

WNA16_SUPPORTED_TYPES_MAP = {
    4: scalar_types.uint4b8,
    8: scalar_types.uint8b128
}
WNA16_ZP_SUPPORTED_TYPES_MAP = {
    4: scalar_types.uint4,
    8: scalar_types.uint8
}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())

4비트와 8비트를 지원하며, Marlin 커널로 가속한다. Zero Point 유무에 따라 다른 scalar type을 사용한다.

7. 텐서별 스케일 후처리: requantize_with_max_scale

Fused 모듈에서 텐서별 양자화 스케일을 통합하는 핵심 함수이다.

def process_weights_after_loading(self, layer):
    if self.strategy == QuantizationStrategy.TENSOR:
        max_w_scale, weight = requantize_with_max_scale(
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            logical_widths=layer.logical_widths,
        )
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

QKV가 fused된 경우 Q/K/V 각각의 스케일이 다를 수 있으므로, 최대 스케일로 통합 후 가중치를 재양자화한다.

전체 스킴 비교 표

스킴 가중치 활성화 스케일 전략 최소 GPU 메모리 절감 정확도
W8A8 FP8 (텐서) FP8 E4M3 FP8 E4M3 텐서별 SM89 2x 높음
W8A8 FP8 (채널) FP8 E4M3 FP8 E4M3 채널별 SM89 2x 매우 높음
W8A8 FP8 (블록) FP8 E4M3 FP8 E4M3 블록별 SM89 2x 최고
W8A8 INT8 INT8 INT8 텐서/채널 SM80 2x 높음
W4A4 NVFP4 FP4 FP4 그룹+전역 SM100 4x 중간
W8A16 FP8 FP8 FP16/BF16 텐서/채널 SM89 2x (W만) 높음
WNA16 (N=4) INT4 FP16/BF16 그룹 SM80 4x (W만) 중간-높음
WNA16 (N=8) INT8 FP16/BF16 그룹 SM80 2x (W만) 높음

설계 근거

  1. 전략 기반 추상화: QuantizationStrategy enum으로 TENSOR/CHANNEL/BLOCK/TOKEN/GROUP을 구분하여, 스킴 내부에서 조건 분기를 명확하게 관리한다.
  2. 후처리 단계 분리: create_weightsprocess_weights_after_loadingapply_weights의 3단계로 분리하여, 가중치 로딩과 양자화 변환을 독립적으로 처리한다.
  3. Fused 모듈 호환: QKV fused 모듈에서 logical_widths로 각 파티션의 논리적 크기를 추적하여, 텐서별 스케일 통합을 정확하게 수행한다.
  4. Weight-Only 지원: W8A16, WNA16 스킴으로 활성화 양자화 없이도 가중치 압축의 이점을 얻을 수 있다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글