[SGLang] MoE 전용 양자화: 전문가별 독립 양자화 전략
들어가며
MoE(Mixture of Experts) 모델은 전문가 수만큼 가중치가 존재하므로 양자화의 필요성이 더 크다. 그러나 각 전문가의 가중치 분포가 다르기 때문에, 전문가별 독립적인 스케일 관리가 필요하다. SGLang은 compressed_tensors/schemes/ 디렉터리의 *_moe.py 파일에서 FP8, INT8, NVFP4, MxINT4 등 MoE 전용 양자화 스킴을 구현한다.
구조도
MoE 양자화 스킴 계층 구조
CompressedTensorsMoEScheme (추상 기반)
├── CompressedTensorsW8A8Fp8MoE (FP8 MoE)
├── CompressedTensorsW4A4Nvfp4MoE (NVFP4 MoE)
├── CompressedTensorsMxInt4MoE (MxINT4 MoE)
├── CompressedTensorsWNA16MoE (WNA16 Marlin MoE)
├── NPUCompressedTensorsW8A8Int8DynamicMoE (NPU INT8)
└── NPUCompressedTensorsW4A8Int8DynamicMoE (NPU W4A8)
MoE 가중치 구조 (전문가 E개):
w13_weight: [E, 2*intermediate, hidden] (gate + up fused)
w2_weight: [E, hidden, intermediate] (down projection)
w13_scale: [E, scale_n, scale_k] (전문가별 독립 스케일)
w2_scale: [E, scale_n, scale_k]
핵심 코드 분석
1. FP8 MoE: 텐서/채널/블록 전략
FP8 MoE는 Dense FP8과 달리 세 가지 양자화 전략을 전문가 차원에서 관리한다.
class CompressedTensorsW8A8Fp8MoE(CompressedTensorsMoEScheme):
def __init__(self, weight_quant, input_quant):
per_tensor = (
self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy == QuantizationStrategy.TENSOR
)
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN
)
if not (per_tensor or per_channel):
assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
self.weight_block_size = self.weight_quant.block_structure
텐서별, 채널별, 블록별 세 가지 전략에 따라 스케일 텐서의 형태가 달라진다.
2. 전문가별 독립 스케일 생성
각 전략에 따라 스케일 텐서의 차원이 다르다.
# 텐서별: 전문가당 1개 (w13은 gate/up 2개)
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32),
requires_grad=False)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32),
requires_grad=False)
# 채널별: 전문가 x 출력 채널
elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size_per_partition,
1, dtype=torch.float32),
requires_grad=False)
# 블록별: 전문가 x 블록 그리드
elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts,
2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32),
requires_grad=False)
3. 텐서별 전문가 스케일 통합
w13(gate+up)의 텐서별 스케일은 두 파티션의 최대값으로 통합한다.
def process_weights_after_loading(self, layer):
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.num_local_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start+shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
layer.w13_weight[expert_id][start:start+shard_size, :], _ = \
scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False)
gate와 up의 스케일이 다르면, 최대 스케일로 통합한 뒤 가중치를 역양자화-재양자화한다. 이는 FP8 MoE 커널이 w13에 대해 단일 스케일만 수용하기 때문이다.
4. MoE 러너 디스패치
양자화 전략에 따라 Triton, FlashInfer TrtLLM, AIter 등 다른 MoE 러너를 사용한다.
def apply_weights(self, layer, dispatch_output):
if _use_aiter and self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
# AMD AIter: per-token per-channel FP8 MoE
output = fused_moe(x, layer.w13_weight, layer.w2_weight,
topk_weights, topk_ids,
quant_type=QuantType.per_Token,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale)
elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
if self.use_flashinfer_trtllm:
quant_info = FlashInferTrtllmFp8MoeQuantInfo(...)
else:
quant_info = TritonMoeQuantInfo(
use_fp8_w8a8=True,
block_shape=self.weight_block_size, ...)
else:
quant_info = TritonMoeQuantInfo(
use_fp8_w8a8=True,
per_channel_quant=..., ...)
return self.runner.run(dispatch_output, quant_info)
5. NVFP4 MoE: Blackwell 전용
NVFP4 MoE는 Blackwell(SM100+)에서만 동작하며, CUTLASS 또는 FlashInfer TrtLLM 커널을 사용한다.
class CompressedTensorsW4A4Nvfp4MoE(CompressedTensorsMoEScheme):
def __init__(self):
if not is_blackwell_supported():
raise ValueError(
"Current platform does not support NVFP4 quantization.")
self.group_size = 16
self.use_flashinfer_trtllm = get_moe_runner_backend().is_flashinfer_trtllm()
NVFP4 MoE는 가중치, 그룹별 스케일, 전역 스케일, 입력 전역 스케일까지 4종류의 파라미터를 관리한다.
6. NVFP4 MoE TrtLLM 경로
FlashInfer TrtLLM 백엔드는 전용 가중치 레이아웃으로 변환이 필요하다.
def process_weights_after_loading(self, layer):
if self.use_flashinfer_trtllm:
(gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
) = prepare_static_weights_for_trtllm_fp4_moe(
layer.w13_weight, layer.w2_weight,
layer.w13_weight_scale, layer.w2_weight_scale,
hidden_size, intermediate_size, num_experts,
)
layer.g1_scale_c = torch.nn.Parameter(
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False)
TrtLLM 경로에서는 가중치를 셔플하고, 스케일을 미리 조합하여 커널 실행 시 오버헤드를 줄인다.
7. NVFP4 MoE CUTLASS 경로
CUTLASS 경로에서는 blockscale을 swizzle하여 메모리 접근 패턴을 최적화한다.
else:
layer.w13_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w13_weight_scale), requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w2_weight_scale), requires_grad=False)
layer.cutlass_moe_params = CutlassMoEParams(
CutlassMoEType.BlockscaledFP4,
layer.w13_weight.device,
num_experts=layer.num_experts,
intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,
hidden_size=layer.w13_weight.shape[2] * 2,
)
8. MxINT4 MoE: FlashInfer TrtLLM 전용
MxINT4는 Blackwell에서 FlashInfer TrtLLM 백엔드만 지원한다.
class CompressedTensorsMxInt4MoE(CompressedTensorsMoEScheme):
def __init__(self, quant_config):
config = self.quant_config.target_scheme_map["Linear"].get("weights")
assert (config.strategy == "group"
and config.group_size == 32
and config.num_bits == 4)
assert config.symmetric
assert get_moe_runner_backend().is_flashinfer_trtllm()
그룹 크기 32, 4비트, 대칭 양자화만 지원하는 엄격한 제약이 있다.
MoE 양자화 스킴 비교
| 스킴 | 비트 (W/A) | 스케일 전략 | 최소 GPU | 백엔드 | 특징 |
|---|---|---|---|---|---|
| W8A8 FP8 텐서 | 8/8 | 전문가당 1개 | SM80 | Triton | 가장 단순 |
| W8A8 FP8 채널 | 8/8 | 전문가x채널 | SM80 | Triton/AIter | 높은 정확도 |
| W8A8 FP8 블록 | 8/8 | 전문가x블록 | SM80 | Triton/TrtLLM | 최고 정확도 |
| W4A4 NVFP4 | 4/4 | 그룹+전역 | SM100 | CUTLASS/TrtLLM | 최대 압축 |
| MxINT4 | 4/16 | 그룹(32) | SM100 | TrtLLM만 | INT4 weight-only |
| W8A8 INT8 | 8/8 | 채널별 | SM80 | NPU전용 | Ascend 최적화 |
설계 근거
- 전문가 독립 스케일: 각 전문가의 가중치 분포가 다르므로, 전문가마다 독립적인 스케일을 관리하여 양자화 오차를 최소화한다.
- w13 스케일 통합: gate와 up 프로젝션이 fused되어 있으므로, 두 파티션의 스케일을 통합해야 한다. 역양자화-재양자화 비용은 로딩 시 한 번만 발생한다.
- 다중 백엔드: 같은 양자화 스킴이라도 GPU 벤더(NVIDIA/AMD)와 아키텍처(Hopper/Blackwell)에 따라 최적 커널이 다르므로, 런타임에 자동 선택한다.
- TrtLLM 레이아웃 변환: TensorRT-LLM 커널은 특수한 메모리 레이아웃을 요구하므로, 가중치 로딩 후 한 번만 변환하여 추론 시 오버헤드를 제거한다.
관련 포스트
참고
- FP8 MoE:
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8_moe.py - NVFP4 MoE:
compressed_tensors_w4a4_nvfp4_moe.py - MxINT4 MoE:
compressed_tensors_w4a4_mxint4_moe.py - INT8 MoE:
compressed_tensors_w8a8_int8_moe.py
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] W4A8, W8A8, W4A4: 혼합 정밀도 양자화 스킴
- 현재글 : [SGLang] MoE 전용 양자화: 전문가별 독립 양자화 전략
- 다음글 [SGLang] 하드웨어별 양자화 튜닝: B200, H100, MI300X 최적 설정
댓글