본문으로 건너뛰기

[llm-compressor] MSE Observer: Grid Search로 양자화 오차 최소화

들어가며

MinMax observer는 단순하지만 극단값(outlier)에 취약하다. 한 채널에 극단적 값이 하나만 있어도 스케일이 크게 벌어지고, 정상 범위의 값들은 양자화 레벨 대부분을 낭비하게 된다. 이를 해결하는 고전적 접근이 MSE(Mean Squared Error) 최소화다. 양자화 범위를 조금씩 줄여가며 "어느 범위에서 양자화 오차가 가장 작은지"를 찾는다. src/llmcompressor/observers/mse.pyMemorylessMSEObserverMovingAverageMSEObserver를 분석한다.

핵심 구조/코드 분석

MemorylessMSEObserver: 한 번의 관측에서 최적 범위 탐색

@Observer.register("memoryless_mse")
class MemorylessMSEObserver(Observer):
    """
    Compute quantization parameters by finding the optimal min/max values which
    minimize the mean of quantization error squared

    mse_quant_error := mean((x - fake_quant(x))**2)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        observer_kwargs = self.args.observer_kwargs
        self.maxshrink = observer_kwargs.get("maxshrink", 0.20)   # 최대 축소 비율
        self.patience = observer_kwargs.get("patience", 5)        # early stopping 인내
        self.grid = observer_kwargs.get("grid", 100.0)            # 탐색 해상도
        self.norm = observer_kwargs.get("norm", 2.4)              # 오차 지수 (2=MSE)

    def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
        global_scale = self._get_module_param("global_scale")
        return _grid_search_mse(
            observed,
            self.args,
            self.maxshrink,
            self.patience,
            self.grid,
            self.norm,
            global_scale=global_scale,
            optimize_global_scale=False,
        )

    def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
        return _grid_search_mse(
            observed,
            self.args,
            self.maxshrink,
            self.patience,
            self.grid,
            self.norm,
            global_scale=None,
            optimize_global_scale=True,    # global scale 도 함께 최적화
        )
파라미터 기본값 의미
maxshrink 0.20 최대로 범위를 얼마나 줄일지 (20%)
patience 5 연속 N 회 개선 없으면 조기 중단
grid 100.0 탐색 해상도. 높을수록 미세한 shrink 탐색
norm 2.4 오차 지수. 2=순수 MSE, 2.4=약간 보수적

norm=2.4는 흥미롭다. 순수 MSE(2)가 아니라 약간 큰 지수를 쓴다. 이는 큰 오차에 더 큰 페널티를 주기 위함이다. 2.4는 실험적으로 잘 작동하는 값으로 GPTQ 논문 이후 PTQ 구현들의 관습이다.

_grid_search_mse: 핵심 탐색 알고리즘

def _grid_search_mse(
    observed: torch.Tensor,                  # (num_observations, *qparam_shape, group_size)
    args: QuantizationArgs,                  # 양자화 설정
    maxshrink: float,                        # 최대 축소 비율 (예: 0.20)
    patience: float,                         # early stopping
    grid: float,                             # 탐색 해상도 (예: 100)
    norm: float,                             # 오차 지수
    global_scale: Optional[torch.Tensor] = None,
    optimize_global_scale: bool = False,
) -> MinMaxTuple:
    # 1) 초기 min/max (shrink=0) 계산
    min_val = torch.amin(observed, dim=(0, -1))
    max_val = torch.amax(observed, dim=(0, -1))

    # 2) 탐색 중 가장 좋은 결과를 저장할 버퍼 초기화
    best_error = torch.full_like(min_val, torch.finfo(min_val.dtype).max)
    best_min_val = min_val.clone()
    best_max_val = max_val.clone()

    no_improve_count = 0

    # 3) Grid search 루프 — shrink factor p 를 1.0 부터 0.8 까지 점진적으로 줄임
    for i in range(int(maxshrink * grid)):  # i=0..20 (기본값 기준)
        p = 1 - i / grid                     # 첫 iter: 1.0, 마지막: 0.8

        shrinked_min_val = p * min_val       # 범위 축소
        shrinked_max_val = p * max_val

        # 4) 마이크로스케일이면 global_scale 도 함께 재계산
        if optimize_global_scale:
            global_scale = generate_gparam(shrinked_min_val, shrinked_max_val)

        # 5) 축소된 범위로 scale/zero_point 계산
        candidate_scales, candidate_zero_points = calculate_qparams(
            min_vals=shrinked_min_val,
            max_vals=shrinked_max_val,
            quantization_args=args,
            global_scale=global_scale,
        )

        # 6) Fake quantize 후 원본과의 차이를 계산
        with patch_attr(args, "strategy", QuantizationStrategy.TOKEN):
            q = fake_quantize(
                observed,
                candidate_scales.unsqueeze(-1),
                candidate_zero_points.unsqueeze(-1),
                args,
                global_scale=global_scale,
            ).to(observed.dtype)

        q -= observed            # q = fake_quant(x) - x
        q.abs_()                 # |q|
        q.pow_(norm)             # |q|^norm
        err = torch.sum(q, dim=(0, -1))   # 채널별 오차 집계

        # 7) 개선된 채널만 업데이트 (element-wise)
        tmp = err < best_error
        if torch.any(tmp):
            best_error[tmp] = err[tmp]
            best_min_val[tmp] = shrinked_min_val[tmp]
            best_max_val[tmp] = shrinked_max_val[tmp]
            no_improve_count = 0
        else:
            no_improve_count += 1
            if no_improve_count >= patience:
                break           # early stopping

    return best_min_val, best_max_val

알고리즘의 핵심 아이디어

이 코드는 다음 프로세스를 수행한다.

  1. 출발점: min/max를 실제 관측 극값으로 세팅 (shrink=0, 즉 100% 범위)
  2. shrink factor p를 1.0 → 0.8로 줄여가며 반복
  3. p에 대해 (p·min, p·max) 범위로 fake_quantize를 수행하고 오차 측정
  4. 채널별로 독립적으로 best error 추적 (각 채널이 자신만의 최적 p를 가질 수 있음)
  5. patience 회 연속 개선 없으면 조기 중단

채널별 독립적 업데이트가 핵심이다. tmp = err < best_error는 element-wise 비교이고, best_min_val[tmp] = ...는 개선된 채널만 갱신한다. 즉 한 채널은 p=0.95가 최적이고 다른 채널은 p=0.85가 최적일 수 있는데, 이 코드는 각각 독립적으로 최적을 찾는다.

MovingAverageMSEObserver: 누적 버전

@Observer.register("mse")
class MovingAverageMSEObserver(MovingAverageObserverBase):
    """Moving average 기반 MSE observer — 기본값"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        observer_kwargs = self.args.observer_kwargs
        self.maxshrink = observer_kwargs.get("maxshrink", 0.20)
        self.patience = observer_kwargs.get("patience", 5)
        self.grid = observer_kwargs.get("grid", 100.0)
        self.norm = observer_kwargs.get("norm", 2.4)

    def get_current_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
        # 부모가 moving average 로 누적
        global_scale = self._get_module_param("global_scale")
        return _grid_search_mse(observed, self.args, ..., optimize_global_scale=False)

    def get_current_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
        return _grid_search_mse(observed, self.args, ..., optimize_global_scale=True)

"mse"라는 이름으로 등록된 이 variant는 MovingAverageObserverBase를 상속한다. get_current_min_max로 이번 관측의 최적 min/max를 구하고, 부모가 이를 지수 이동 평균으로 누적한다.

왜 이 설계인가

1. Grid search의 단순성. 복잡한 최적화 알고리즘을 쓰지 않고 순차적으로 shrink factor를 줄여간다. 코드가 단순하고 GPU 병렬성을 잘 활용한다. 모든 채널이 동시에 계산되므로 수백 채널 단위도 문제없다.

2. 채널별 독립적 최적. tmp = err < best_error의 element-wise 비교는 각 채널의 최적 shrink factor가 다를 수 있다는 현실을 반영한다. 한 채널이 빨리 수렴하고 다른 채널이 늦게 수렴해도 개별로 처리된다.

3. Early stopping. patience=5 기본값은 모든 채널이 개선을 멈춘 지 5회 연속이면 탐색을 중단한다. grid=100, maxshrink=0.2라면 최대 20회 탐색하지만, 대부분 훨씬 일찍 멈춘다. 이는 채널 수가 많고 데이터가 큰 경우 큰 속도 이점을 준다.

4. norm=2.4의 경험적 선택. 순수 MSE보다 약간 큰 지수가 실전에서 더 좋은 PTQ 정확도를 준다는 것은 GPTQ 논문 시기부터 관찰된 경험적 사실이다. 사용자가 원하면 observer_kwargs로 오버라이드할 수 있다.

5. optimize_global_scale 분기. 마이크로스케일 스킴은 "블록 스케일"과 "전역 스케일"을 모두 최적화해야 하는데, 두 경우를 같은 grid search 코드로 처리한다. 인자 하나로 분기하므로 코드 중복이 없다.

마무리

MSE Observer는 MinMax보다 느리지만 외곽값에 훨씬 강건하다. AWQ나 일부 정밀 양자화 시나리오에서 기본값으로 쓰인다. 다음 글은 여러 배치를 누적하는 Moving Average Observer의 베이스를 본다.

참고 자료

댓글

관련 포스트

llm-compressor 의 다른글