본문으로 건너뛰기

[SGLang] AWQ: 활성화 인식 가중치 양자화

들어가며

AWQ(Activation-aware Weight Quantization)는 가중치를 양자화할 때 활성화 분포를 고려하여 중요한 채널에 더 높은 정밀도를 할당하는 기법이다. SGLang은 python/sglang/srt/layers/quantization/awq.py에서 AWQ를 구현하며, 기본 AWQ 커널과 AWQ Marlin 최적화 커널 두 가지 경로를 제공한다.

구조도

AWQ 양자화 구조
┌────────────────────────────────────────────────┐
  AWQConfig                                      
  ├── weight_bits: 4                             
  ├── group_size: 128                            
  ├── zero_point: bool                           
  └── pack_factor: 32 // weight_bits = 8         
                                                  
  AWQMarlinConfig (Marlin 최적화 버전)            
  ├── quant_type: scalar_types.uint4             
  └── Marlin GPU 커널 사용                        
└────────────────────────────────────────────────┘

디스패치 흐름:
  AWQConfig.get_quant_method()
  ├── LinearBase  AWQLinearMethod (기본)
                   또는 AWQLinearAscendMethod (NPU)
  └── FusedMoE    AWQMoEAscendMethod (NPU만)

  AWQMarlinConfig.get_quant_method()
  ├── LinearBase  AWQMarlinLinearMethod
  └── FusedMoE    AWQMoEMethod

핵심 코드 분석

1. AWQConfig: 4비트 전용 설정

AWQ는 현재 4비트 가중치 양자화만 지원한다.

class AWQConfig(QuantizationConfig):
    def __init__(self, weight_bits: int, group_size: int,
                 zero_point: bool,
                 modules_to_not_convert: Optional[List[str]] = None):
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is "
                f"supported for AWQ, but got {self.weight_bits} bits."
            )
        self.pack_factor = 32 // self.weight_bits  # 8

pack_factor = 8은 32비트 정수 하나에 8개의 4비트 가중치를 패킹한다는 의미이다.

2. 플랫폼별 디스패치

AWQ는 CUDA, HIP(AMD), XPU(Intel), NPU에 따라 다른 역양자화 커널을 사용한다.

if _is_cuda:
    from sglang.jit_kernel.awq_dequantize import awq_dequantize
    from sglang.jit_kernel.awq_marlin_repack import (
        awq_marlin_moe_repack, awq_marlin_repack,
    )
elif _is_hip:
    from sglang.srt.layers.quantization.awq_triton import (
        awq_dequantize_triton as awq_dequantize,
    )
elif _is_xpu:
    from sgl_kernel import awq_dequantize

CUDA에서는 JIT 컴파일된 커널을, AMD GPU에서는 Triton 기반 커널을 사용한다.

3. AWQLinearMethod: 가중치 생성

AWQ 가중치는 그룹 단위로 양자화되며, qweight/qzeros/scales 세 텐서로 구성된다.

class AWQLinearMethod(LinearMethodBase):
    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, ...):
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape.")
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape.")

입력 크기가 group_size로, 출력 크기가 pack_factor로 나누어 떨어져야 한다. 이는 Tensor Parallel 환경에서 파티션 크기가 양자화 그룹과 정렬되어야 하기 때문이다.

4. AWQ Marlin: Marlin 커널 자동 전환

AWQ Marlin은 Marlin GPU 커널을 활용하여 AWQ 추론을 가속한다.

class AWQMarlinConfig(QuantizationConfig):
    TYPE_MAP = {
        4: scalar_types.uint4,
        8: scalar_types.uint8,
    }

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant):
        can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
        is_valid_user_quant = (
            user_quant is None or user_quant == "marlin"
            or user_quant == "awq_marlin"
        )
        if can_convert and is_valid_user_quant:
            return cls.get_name()  # "awq_marlin"

사용자가 명시적으로 awq를 지정하지 않으면, Marlin 호환 가능한 경우 자동으로 awq_marlin으로 전환한다. 이는 상당한 성능 향상을 제공한다.

5. Marlin 호환성 검사

모든 AWQ 모델이 Marlin으로 변환 가능한 것은 아니다.

@classmethod
def is_awq_marlin_compatible(cls, quant_config):
    quant_method = quant_config.get("quant_method", "").lower()
    num_bits = quant_config.get("bits")
    group_size = quant_config.get("group_size")
    zero_point = quant_config.get("zero_point")

    if not _is_cuda:
        return False
    if quant_method != "awq":
        return False
    if num_bits not in cls.TYPE_MAP:
        return False
    return check_marlin_supported(
        quant_type=cls.TYPE_MAP[num_bits],
        group_size=group_size, has_zp=zero_point
    )

Marlin은 CUDA 전용이며, 특정 비트 수와 그룹 크기 조합만 지원한다.

6. 레이어 스킵 로직

특정 레이어는 양자화에서 제외할 수 있다.

def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
    return any(module_name in prefix
               for module_name in modules_to_not_convert)

lm_headembed_tokens 같은 레이어는 양자화 시 성능 저하가 크므로 제외하는 것이 일반적이다.

7. FusedMoE 레이어 폴백

FusedMoE 레이어가 Marlin을 지원하지 않으면 WNA16 커널로 폴백한다.

def get_quant_method(self, layer, prefix):
    if isinstance(layer, FusedMoE):
        if not check_moe_marlin_supports_layer(layer, self.group_size):
            logger.warning_once(
                f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
                "Falling back to Moe WNA16 kernels."
            )
            return MoeWNA16Config.from_config(self.full_config) \
                .get_quant_method(layer, prefix)
        return AWQMoEMethod(self)

8. 설정 파일 자동 감지

AWQ 모델은 두 가지 설정 파일명을 사용한다.

@staticmethod
def get_config_filenames() -> List[str]:
    return [
        "quant_config.json",      # casperhansen/vicuna-7b-v1.5-awq
        "quantize_config.json",   # abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
    ]

AWQ vs GPTQ 비교

항목 AWQ GPTQ
양자화 기준 활성화 분포 기반 Hessian 행렬 기반
보정 데이터 적은 양 필요 상대적으로 많이 필요
양자화 시간 빠름 느림
비트 수 4비트 2/3/4/8비트
정확도 높음 (활성화 인식) 높음 (2차 최적화)
Marlin 지원 AWQ Marlin GPTQ Marlin
최소 GPU SM75 (Turing) SM80 (Ampere)

설계 근거

  1. 활성화 인식: 중요하지 않은 채널의 가중치는 낮은 정밀도로, 중요한 채널은 스케일링으로 보호하여 전체 정확도를 유지한다.
  2. Marlin 자동 전환: Marlin 커널은 AWQ의 INT4 GEMM을 GPU 텐서 코어에 최적화하여 2-4배 처리량 향상을 달성한다.
  3. MoE 폴백 전략: Marlin이 지원하지 않는 MoE 구조에서도 WNA16 커널로 양자화 추론이 가능하다.
  4. 크로스 플랫폼: CUDA, HIP, XPU, NPU 각각에 최적화된 역양자화 커널을 제공한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글