[llm-compressor] iMatrix Observer: 입력 채널 중요도 가중 MSE
들어가며
MSE Observer는 채널별로 독립적으로 양자화 오차를 최소화한다. 그런데 "모든 채널이 똑같이 중요한가?"라는 질문을 던지면, 그렇지 않다. 어떤 입력 채널은 활성화 크기가 크고 출력에 기여가 크지만, 다른 채널은 거의 0에 가까운 활성화를 가진다. llama.cpp의 iMatrix(importance matrix) 개념은 이 차이를 스케일 결정에 반영한다. 즉 중요한 채널에서 양자화 오차가 크면 더 큰 페널티를 주는 가중 MSE를 쓴다. llm-compressor의 src/llmcompressor/observers/imatrix.py가 이를 구현한다.
공식 문서
핵심 구조/코드 분석
IMatrixMSEObserver 생성자
@Observer.register("imatrix_mse")
class IMatrixMSEObserver(Observer):
"""
MSE observer weighted by per-input-channel importance.
Supports CHANNEL, GROUP, and TENSOR_GROUP for weight-only Linear modules.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
kw = self.args.observer_kwargs
self.maxshrink = kw.get("maxshrink", 0.95) # 최대 95%까지 축소 (훨씬 공격적)
self.patience = kw.get("patience", 5)
self.grid = kw.get("grid", 20) # 20 단계 탐색
self.norm = kw.get("norm", 3.0) # 오차 지수 3 (MSE보다 더 극단 페널티)
self.strict = kw.get("strict", False) # 검증 실패 시 예외 vs 경고
# 파라미터 검증
if self.grid <= 0:
raise ValueError(f"grid must be > 0, got {self.grid}")
if not (0 <= self.maxshrink <= 1):
raise ValueError(f"maxshrink must be in [0, 1], got {self.maxshrink}")
| 파라미터 | 기본값 | 의미 |
|---|---|---|
maxshrink |
0.95 | 최대로 범위를 95%까지 줄일 수 있음. 훨씬 공격적 |
grid |
20 | 탐색 해상도. MSE(100)보다 낮지만 단계가 크다 |
norm |
3.0 | 오차 지수 3. 큰 오차에 훨씬 강한 페널티 |
strict |
False | 검증 실패 시 True면 예외, False면 경고 후 uniform MSE fallback |
파라미터가 MSE Observer와 크게 다른 것이 특징이다. maxshrink=0.95는 "범위를 5%까지 줄여도 좋다"는 것으로, 극단적으로 공격적인 탐색을 허용한다. norm=3.0은 큰 오차 하나를 작은 오차 여러 개보다 훨씬 나쁘게 평가한다. iMatrix의 철학은 "중요한 채널에서의 큰 오차를 무조건 피하라"이다.
attach: Forward Pre-Hook으로 $E[x^2]$ 수집
def attach(self, module: torch.nn.Module) -> None:
"""forward-pre hook 을 등록해 입력 채널별 E[x^2] 누적"""
if hasattr(module, "_imatrix_importance"):
return # 이미 두 번째 pass면 스킵
if not hasattr(module, "in_features"):
return
in_features = module.in_features
module._imatrix_sum = torch.zeros(in_features, dtype=IMATRIX_PRECISION) # FP32 누적기
module._imatrix_count = 0
def _hook(mod, args):
x = args[0] if isinstance(args, tuple) else args
if isinstance(x, tuple):
x = x[0]
if x is None or not isinstance(x, torch.Tensor):
return
x_f = x.detach().to(IMATRIX_PRECISION) # FP32 로 캐스트
n_tokens = math.prod(x_f.shape[:-1]) # 전체 토큰 수
token_sum = x_f.pow(2).sum(dim=list(range(x_f.dim() - 1))) # 채널별 제곱합
if mod._imatrix_sum.device != token_sum.device:
mod._imatrix_sum = mod._imatrix_sum.to(token_sum.device)
mod._imatrix_sum.add_(token_sum)
mod._imatrix_count += n_tokens
module._imatrix_hook = module.register_forward_pre_hook(_hook)
forward pre-hook은 모듈이 호출되기 직전에 입력 텐서를 캡처한다. iMatrix는 입력의 제곱합을 각 채널별로 누적한다. 수식으로는
$$ \text{imp}c = \frac{1}{N} \sum{t=1}^{N} x_{t,c}^2 $$
여기서 $c$는 입력 채널, $N$은 전체 토큰 수. 이는 해당 채널의 평균 에너지를 의미한다. 에너지가 큰 채널은 출력에 기여가 크므로 "중요하다"고 본다.
FP32로 누적하는 이유는 수치 안정성이다. FP16/BF16은 제곱합이 빠르게 오버플로우될 수 있어 큰 모델에서 위험하다.
detach: 누적 결과를 중요도로 변환
def detach(self, module: torch.nn.Module) -> None:
"""hook 제거 + 중요도 계산 + 정리"""
if hasattr(module, "_imatrix_sum"):
if module._imatrix_count > 0:
importance = module._imatrix_sum / module._imatrix_count
module._imatrix_importance = importance # 다음 pass에서 쓸 수 있게 남겨둠
if hasattr(module, "_imatrix_hook"):
module._imatrix_hook.remove()
del module._imatrix_hook
del module._imatrix_sum
del module._imatrix_count
return
# 최종 정리 pass — _imatrix_importance 까지 제거 (체크포인트에 포함되지 않도록)
if hasattr(module, "_imatrix_importance"):
del module._imatrix_importance
detach는 두 가지 모드로 작동한다.
- 첫 호출: 누적 버퍼로부터
importance = sum / count를 계산해_imatrix_importance에 저장한다. 다음 pass에서 이 값을 쓸 수 있도록 남겨둔다. - 두 번째 호출(최종 정리):
_imatrix_importance를 제거한다. 이는 "양자화가 끝났으니 메타데이터는 체크포인트에 포함시키지 말자"는 의도다.
이 두 단계 설계는 llm-compressor의 실행 흐름과 맞물려 있다. 첫 pass에서 IMatrixGatherer Modifier가 hook을 붙여 통계를 모으고, 두 번째 pass에서 QuantizationModifier가 이 통계를 사용해 실제 양자화를 수행한다. 두 pass가 끝난 후 최종 정리가 한 번 더 일어난다.
get_min_max: 가중 Grid Search
def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
return _grid_search(
observed,
self.args,
self.maxshrink,
self.patience,
self.grid,
self.norm,
global_scale=self._get_module_param("global_scale"),
importance_weights=self._prepare_importance(observed), # 중요도 가중치 전달
)
핵심은 _prepare_importance가 중요도 텐서를 _grid_search에 넘긴다는 것이다. 기본 grid search 로직은 MSE Observer와 같지만, 오차 계산에 중요도 가중치가 곱해진다.
# _grid_search 내부
q.sub_(observed_f).abs_().pow_(norm) # |fake_quant(x) - x|^norm
if importance_weights is not None:
q.mul_(importance_weights) # 채널별 중요도로 가중
err = q.sum(dim=(0, -1))
즉 각 채널의 양자화 오차에 그 채널의 중요도가 곱해진다. 중요한 채널에서의 오차는 크게 반영되고, 덜 중요한 채널은 작게 반영된다. 결과적으로 최적 스케일이 "중요한 채널의 양자화 품질"을 우선시한다.
_prepare_importance: 검증과 전처리
def _prepare_importance(self, observed: torch.Tensor) -> Optional[torch.Tensor]:
imp = self._get_validated_importance(observed) # 검증 (아래 참조)
if imp is None:
return None
imp = imp.to(device=observed.device, dtype=torch.float32)
imp = imp / (imp.mean() + torch.finfo(torch.float32).tiny) # 평균 1 로 정규화
module = self.module() if self.module is not None else None
if module is None or not hasattr(module, "weight"):
return None
# 중요도를 가중치와 같은 shape 로 확장
out_features = module.weight.shape[0]
imp_2d = imp.unsqueeze(0).expand(out_features, -1)
g_idx = getattr(module, f"{self.base_name}_g_idx", None)
return flatten_for_calibration(imp_2d, self.base_name, self.args, g_idx)
두 단계 전처리:
- 정규화:
imp / imp.mean()으로 평균을 1로 맞춘다. 이는 grid search의best_error와 같은 스케일에서 비교될 수 있도록 한다. - 확장과 평탄화: 1D 중요도 벡터를
(out_features, in_features)2D로 확장하고,flatten_for_calibration으로 관측 텐서와 같은 shape로 평탄화한다. g_idx가 있으면 재배열까지 처리된다.
검증 로직: 안전한 Fallback
_get_validated_importance는 "이 observer가 이 모듈에 적용 가능한가"를 꼼꼼히 검사한다. 실패 시 strict=True면 예외, strict=False(기본)이면 경고 후 None 반환 → uniform MSE로 fallback.
검증 항목:
base_name == "weight"(activation observer는 미지원)isinstance(module, torch.nn.Linear)(Linear만 지원)strategy != TENSOR(텐서 단위는 미지원)_imatrix_importance가 존재- 1D 텐서
- Finite, non-negative, non-zero values
imp.numel() == expected(strategy에 따른 기대 shape)
조건 하나라도 어긋나면 uniform MSE로 돌아가고 로그에 한 번만 경고한다. 이 방어적 검증 덕분에 사용자는 "어떤 모델에든 imatrix_mse를 지정해두면 지원 모듈에서만 적용되고 나머지는 자동 fallback"을 얻는다.
왜 이 설계인가
1. Forward pre-hook으로 비파괴적 수집. iMatrix 수집이 모델의 forward 동작을 변경하지 않는다. 훅이 입력만 관찰하고, 계산된 통계는 모듈 속성에 저장된다. 제거 시 원본 동작으로 완전 복원.
2. FP32 누적. IMATRIX_PRECISION = torch.float32로 고정. 큰 모델의 활성화 제곱합은 BF16으로는 불안정해서 중요도가 왜곡된다. 약간의 메모리를 더 써도 정확성을 택한다.
3. 두 pass 재사용. 첫 pass에서 통계 수집, 두 번째 pass에서 활용, 세 번째 pass에서 정리. 세 pass가 같은 _imatrix_importance 속성을 매개로 정보를 전달한다. 별도 전역 저장소가 필요 없다.
4. 매우 공격적인 기본값. maxshrink=0.95, norm=3.0은 "중요한 채널의 양자화 오차를 절대 용납하지 않는다"는 정책. uniform MSE(0.20, 2.4)와 크게 다른데, 이는 iMatrix가 2비트 같은 극단 양자화에서 쓰이기 때문이다.
5. 다층 방어 검증. 지원하지 않는 경우가 많지만(activation, 비-Linear, TENSOR strategy 등), 모두 graceful fallback으로 처리한다. 사용자는 지원 여부를 신경 쓰지 않고 레시피를 작성해도 된다.
마무리
iMatrix Observer는 llama.cpp의 중요도 개념을 llm-compressor에 이식한 것이다. 2비트·3비트 같은 극단 양자화에서 정확도 보존에 필수적이다. 이로써 네 개의 observer를 모두 살펴봤다. 다음 글부터는 양자화 Modifier를 본격적으로 파고드는 Quantization Base로 넘어간다.
참고 자료
관련 포스트
llm-compressor 의 다른글
- 이전글 [llm-compressor] Moving Average Observer: 지수 이동 평균 기반 온라인 관측자
- 현재글 : [llm-compressor] iMatrix Observer: 입력 채널 중요도 가중 MSE
- 다음글 [llm-compressor] Quantization Base: QuantizationModifier와 QuantizationMixin
댓글