본문으로 건너뛰기

[llm-compressor] MinMax Observer: 세 가지 min/max 계산 정책

들어가며

Min/max observer는 가장 단순하다. 관측 텐서의 최솟값과 최댓값을 그대로 스케일 계산에 쓴다. 하지만 "어느 시점의 min/max를 쓰느냐"에 따라 세 가지 변형이 존재한다. 기억이 없는(memoryless), 전체 기간을 누적하는(static), 이동 평균을 쓰는(moving average). src/llmcompressor/observers/min_max.py를 분석한다.

핵심 구조/코드 분석

_get_min_max: 공통 헬퍼

def _get_min_max(observed: torch.Tensor) -> MinMaxTuple:
    """
    observed.shape = (num_observations, *qparam_shape, group_size)
    dim=(0, -1) 로 axis reduce → qparam_shape 만 남음
    """
    min_vals = torch.amin(observed, dim=(0, -1))
    max_vals = torch.amax(observed, dim=(0, -1))
    return min_vals, max_vals

Observers Baseflatten_for_calibration이 텐서를 (num_observations, *qparam_shape, group_size) 형태로 만든다. 이 함수는 첫 번째 축(관측 인덱스)과 마지막 축(그룹 내 원소)을 줄여서 qparam_shape만 남긴다. 예를 들어 채널 단위 양자화라면 출력 채널 수 만큼의 min/max 쌍이 반환된다.

MemorylessMinMaxObserver: 현재 관측만 사용

@Observer.register("memoryless_minmax")
class MemorylessMinMaxObserver(Observer):
    """현재 호출에서 본 값으로만 min/max 계산. 과거 값은 기억 안 함"""

    def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
        return _get_min_max(observed)

    def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
        return _get_min_max(observed)

가장 단순한 형태. get_min_max가 호출될 때마다 현재 observed 텐서에서 min/max를 구한다. 이전 호출의 결과는 기억하지 않는다. 이는 "한 번의 호출로 충분한" 경우 — 즉 가중치 양자화나 단일 배치 캘리브레이션 — 에 적합하다. 가중치는 시간에 따라 변하지 않으므로 한 번 보면 끝이다.

StaticMinMaxObserver: 누적 min/max

@Observer.register("static_minmax")
class StaticMinMaxObserver(Observer):
    """여러 관측에서 본 min/max 를 누적해서 최종 값 결정"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.past_min_vals = None                  # 지금까지의 최소값
        self.past_max_vals = None                  # 지금까지의 최대값
        self.past_global_min_vals = None
        self.past_global_max_vals = None

    def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
        min_vals, max_vals = _get_min_max(observed)

        # 과거 값이 있으면 element-wise min/max 로 업데이트
        if self.past_min_vals is not None:
            min_vals = torch.min(min_vals, self.past_min_vals)
            max_vals = torch.max(max_vals, self.past_max_vals)

        self.past_min_vals = min_vals
        self.past_max_vals = max_vals

        return min_vals, max_vals

    def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
        # 위와 동일 — global 용 별도 누적 변수 사용
        min_vals, max_vals = _get_min_max(observed)
        if self.past_global_min_vals is not None:
            min_vals = torch.min(min_vals, self.past_global_min_vals)
            max_vals = torch.max(max_vals, self.past_global_max_vals)
        self.past_global_min_vals = min_vals
        self.past_global_max_vals = max_vals
        return min_vals, max_vals

activation 양자화에서 쓰이는 변형이다. 여러 배치를 순회하면서 "지금까지 본 가장 극단적인 min/max"를 유지한다. 한 번이라도 큰 값이 나오면 그 이후로는 그 값이 채택되어, 양자화 범위가 "모든 관측 데이터를 포용"하도록 확장된다.

past_min_valspast_global_min_vals를 분리하는 이유는 일반 스케일과 전역 스케일(마이크로스케일 스킴)이 독립적으로 누적되어야 하기 때문이다.

MinMaxObserver: 이동 평균 버전

@Observer.register("minmax")
class MinMaxObserver(MovingAverageObserverBase):
    """Moving average 로 min/max 평균 계산"""

    def get_current_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
        return _get_min_max(observed)

    def get_current_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
        return _get_min_max(observed)

이 variant는 MovingAverageObserverBase를 상속한다. get_current_min_max는 "이번 관측의 로컬 min/max"만 반환하고, 실제 누적/평균 로직은 부모 클래스가 처리한다. 덕분에 각 서브클래스는 "현재 관측에서 어떻게 min/max를 뽑는가"만 구현하면 된다.

세 variants의 레지스트리 이름이 중요하다.

  • "memoryless_minmax"MemorylessMinMaxObserver
  • "static_minmax"StaticMinMaxObserver
  • "minmax"MinMaxObserver (기본값)

이 중 "minmax"는 기본값이며, 내부적으로는 moving average 기반이다. 사용자가 observer_type: minmax를 쓸 때 가장 일반적인 동작을 얻는다.

언제 어떤 variant를 쓰는가

시나리오 추천 variant
가중치 양자화 memoryless — 한 번의 관측이면 충분
단일 배치 캘리브레이션 memoryless
여러 배치 activation 양자화 + 외곽값 포함 중요 static — 극값 수렴
여러 배치 activation 양자화 + 평균적 분포 minmax (moving average)

왜 이 설계인가

1. 세 variants의 코드 복제 최소화. 모두 _get_min_max 공통 헬퍼를 공유한다. 실제 차이는 "누적 정책"뿐이어서 각 클래스가 수십 줄로 끝난다.

2. past_* 필드를 인스턴스 변수로. Static variant가 누적 상태를 인스턴스에 보관하면 파이프라인 루프에서 여러 번 호출되어도 동일 observer가 상태를 유지한다. GPTQ 같은 알고리즘이 레이어별로 observer를 생성해 사용할 때 이 상태가 자동으로 격리된다.

3. get_current_min_maxget_min_max 이름 구분. Moving average variant는 베이스 계약이 다르다. 부모(MovingAverageObserverBase)가 get_current_*를 요구하므로, MinMaxObserver는 이를 구현한다. 같은 _get_min_max 헬퍼를 쓰지만 호출 경로가 다르다.

4. global 변종 동일 패턴. get_global_min_maxget_min_max와 같은 패턴을 따른다. 마이크로스케일 스킴 지원이 추후에 추가되었지만 기존 코드와 자연스럽게 통합된다.

5. 레지스트리 이름에 의미 부여. "memoryless_*", "static_*", 기본 이름 "minmax"라는 네이밍 자체가 문서 역할을 한다. 사용자가 레시피 YAML에 observer_type: "memoryless_minmax"를 쓰면 의도가 즉시 드러난다.

마무리

MinMax observer는 단순하지만 활용도가 높다. FP8/INT8 weight-only 양자화에서는 memoryless, activation 양자화에서는 moving average 기반 기본값을 쓴다. 다음 글은 양자화 오차를 grid search로 최소화하는 MSE Observer를 본다.

참고 자료

댓글

관련 포스트

llm-compressor 의 다른글