본문으로 건너뛰기

[SGLang] MoE 전용 양자화: 전문가별 독립 양자화 전략

들어가며

MoE(Mixture of Experts) 모델은 전문가 수만큼 가중치가 존재하므로 양자화의 필요성이 더 크다. 그러나 각 전문가의 가중치 분포가 다르기 때문에, 전문가별 독립적인 스케일 관리가 필요하다. SGLang은 compressed_tensors/schemes/ 디렉터리의 *_moe.py 파일에서 FP8, INT8, NVFP4, MxINT4 등 MoE 전용 양자화 스킴을 구현한다.

구조도

MoE 양자화 스킴 계층 구조
CompressedTensorsMoEScheme (추상 기반)
├── CompressedTensorsW8A8Fp8MoE     (FP8 MoE)
├── CompressedTensorsW4A4Nvfp4MoE   (NVFP4 MoE)
├── CompressedTensorsMxInt4MoE      (MxINT4 MoE)
├── CompressedTensorsWNA16MoE       (WNA16 Marlin MoE)
├── NPUCompressedTensorsW8A8Int8DynamicMoE  (NPU INT8)
└── NPUCompressedTensorsW4A8Int8DynamicMoE  (NPU W4A8)

MoE 가중치 구조 (전문가 E개):
w13_weight: [E, 2*intermediate, hidden]  (gate + up fused)
w2_weight:  [E, hidden, intermediate]     (down projection)
w13_scale:  [E, scale_n, scale_k]         (전문가별 독립 스케일)
w2_scale:   [E, scale_n, scale_k]

핵심 코드 분석

1. FP8 MoE: 텐서/채널/블록 전략

FP8 MoE는 Dense FP8과 달리 세 가지 양자화 전략을 전문가 차원에서 관리한다.

class CompressedTensorsW8A8Fp8MoE(CompressedTensorsMoEScheme):
    def __init__(self, weight_quant, input_quant):
        per_tensor = (
            self.weight_quant.strategy == QuantizationStrategy.TENSOR
            and self.input_quant.strategy == QuantizationStrategy.TENSOR
        )
        per_channel = (
            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
            and self.input_quant.strategy == QuantizationStrategy.TOKEN
        )
        if not (per_tensor or per_channel):
            assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
            self.weight_block_size = self.weight_quant.block_structure

텐서별, 채널별, 블록별 세 가지 전략에 따라 스케일 텐서의 형태가 달라진다.

2. 전문가별 독립 스케일 생성

각 전략에 따라 스케일 텐서의 차원이 다르다.

# 텐서별: 전문가당 1개 (w13은 gate/up 2개)
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
    w13_weight_scale = torch.nn.Parameter(
        torch.ones(num_experts, 2, dtype=torch.float32),
        requires_grad=False)
    w2_weight_scale = torch.nn.Parameter(
        torch.ones(num_experts, dtype=torch.float32),
        requires_grad=False)

# 채널별: 전문가 x 출력 채널
elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
    w13_weight_scale = torch.nn.Parameter(
        torch.ones(num_experts, 2 * intermediate_size_per_partition,
                   1, dtype=torch.float32),
        requires_grad=False)

# 블록별: 전문가 x 블록 그리드
elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
    w13_weight_scale = torch.nn.Parameter(
        torch.ones(num_experts,
            2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
            (hidden_size + block_k - 1) // block_k,
            dtype=torch.float32),
        requires_grad=False)

3. 텐서별 전문가 스케일 통합

w13(gate+up)의 텐서별 스케일은 두 파티션의 최대값으로 통합한다.

def process_weights_after_loading(self, layer):
    if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.num_local_experts):
            start = 0
            for shard_id in range(2):
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start+shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id],
                )
                layer.w13_weight[expert_id][start:start+shard_size, :], _ = \
                    scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                start += shard_size
        layer.w13_weight_scale = torch.nn.Parameter(
            max_w13_scales, requires_grad=False)

gate와 up의 스케일이 다르면, 최대 스케일로 통합한 뒤 가중치를 역양자화-재양자화한다. 이는 FP8 MoE 커널이 w13에 대해 단일 스케일만 수용하기 때문이다.

4. MoE 러너 디스패치

양자화 전략에 따라 Triton, FlashInfer TrtLLM, AIter 등 다른 MoE 러너를 사용한다.

def apply_weights(self, layer, dispatch_output):
    if _use_aiter and self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
        # AMD AIter: per-token per-channel FP8 MoE
        output = fused_moe(x, layer.w13_weight, layer.w2_weight,
            topk_weights, topk_ids,
            quant_type=QuantType.per_Token,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale)
    elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
        if self.use_flashinfer_trtllm:
            quant_info = FlashInferTrtllmFp8MoeQuantInfo(...)
        else:
            quant_info = TritonMoeQuantInfo(
                use_fp8_w8a8=True,
                block_shape=self.weight_block_size, ...)
    else:
        quant_info = TritonMoeQuantInfo(
            use_fp8_w8a8=True,
            per_channel_quant=..., ...)
    return self.runner.run(dispatch_output, quant_info)

5. NVFP4 MoE: Blackwell 전용

NVFP4 MoE는 Blackwell(SM100+)에서만 동작하며, CUTLASS 또는 FlashInfer TrtLLM 커널을 사용한다.

class CompressedTensorsW4A4Nvfp4MoE(CompressedTensorsMoEScheme):
    def __init__(self):
        if not is_blackwell_supported():
            raise ValueError(
                "Current platform does not support NVFP4 quantization.")
        self.group_size = 16
        self.use_flashinfer_trtllm = get_moe_runner_backend().is_flashinfer_trtllm()

NVFP4 MoE는 가중치, 그룹별 스케일, 전역 스케일, 입력 전역 스케일까지 4종류의 파라미터를 관리한다.

6. NVFP4 MoE TrtLLM 경로

FlashInfer TrtLLM 백엔드는 전용 가중치 레이아웃으로 변환이 필요하다.

def process_weights_after_loading(self, layer):
    if self.use_flashinfer_trtllm:
        (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
         gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
        ) = prepare_static_weights_for_trtllm_fp4_moe(
            layer.w13_weight, layer.w2_weight,
            layer.w13_weight_scale, layer.w2_weight_scale,
            hidden_size, intermediate_size, num_experts,
        )
        layer.g1_scale_c = torch.nn.Parameter(
            (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
            requires_grad=False)

TrtLLM 경로에서는 가중치를 셔플하고, 스케일을 미리 조합하여 커널 실행 시 오버헤드를 줄인다.

7. NVFP4 MoE CUTLASS 경로

CUTLASS 경로에서는 blockscale을 swizzle하여 메모리 접근 패턴을 최적화한다.

else:
    layer.w13_weight_scale = torch.nn.Parameter(
        swizzle_blockscale(layer.w13_weight_scale), requires_grad=False)
    layer.w2_weight_scale = torch.nn.Parameter(
        swizzle_blockscale(layer.w2_weight_scale), requires_grad=False)
    layer.cutlass_moe_params = CutlassMoEParams(
        CutlassMoEType.BlockscaledFP4,
        layer.w13_weight.device,
        num_experts=layer.num_experts,
        intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,
        hidden_size=layer.w13_weight.shape[2] * 2,
    )

8. MxINT4 MoE: FlashInfer TrtLLM 전용

MxINT4는 Blackwell에서 FlashInfer TrtLLM 백엔드만 지원한다.

class CompressedTensorsMxInt4MoE(CompressedTensorsMoEScheme):
    def __init__(self, quant_config):
        config = self.quant_config.target_scheme_map["Linear"].get("weights")
        assert (config.strategy == "group"
                and config.group_size == 32
                and config.num_bits == 4)
        assert config.symmetric
        assert get_moe_runner_backend().is_flashinfer_trtllm()

그룹 크기 32, 4비트, 대칭 양자화만 지원하는 엄격한 제약이 있다.

MoE 양자화 스킴 비교

스킴 비트 (W/A) 스케일 전략 최소 GPU 백엔드 특징
W8A8 FP8 텐서 8/8 전문가당 1개 SM80 Triton 가장 단순
W8A8 FP8 채널 8/8 전문가x채널 SM80 Triton/AIter 높은 정확도
W8A8 FP8 블록 8/8 전문가x블록 SM80 Triton/TrtLLM 최고 정확도
W4A4 NVFP4 4/4 그룹+전역 SM100 CUTLASS/TrtLLM 최대 압축
MxINT4 4/16 그룹(32) SM100 TrtLLM만 INT4 weight-only
W8A8 INT8 8/8 채널별 SM80 NPU전용 Ascend 최적화

설계 근거

  1. 전문가 독립 스케일: 각 전문가의 가중치 분포가 다르므로, 전문가마다 독립적인 스케일을 관리하여 양자화 오차를 최소화한다.
  2. w13 스케일 통합: gate와 up 프로젝션이 fused되어 있으므로, 두 파티션의 스케일을 통합해야 한다. 역양자화-재양자화 비용은 로딩 시 한 번만 발생한다.
  3. 다중 백엔드: 같은 양자화 스킴이라도 GPU 벤더(NVIDIA/AMD)와 아키텍처(Hopper/Blackwell)에 따라 최적 커널이 다르므로, 런타임에 자동 선택한다.
  4. TrtLLM 레이아웃 변환: TensorRT-LLM 커널은 특수한 메모리 레이아웃을 요구하므로, 가중치 로딩 후 한 번만 변환하여 추론 시 오버헤드를 제거한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글