본문으로 건너뛰기

[llm-compressor] SparseGPT: 원샷 LLM 가지치기 구현

들어가며

SparseGPT는 GPTQ와 같은 연구진이 2023년에 발표한 "양자화의 자매 알고리즘"이다. GPTQ가 "가중치를 정수로 반올림"한다면 SparseGPT는 "일부 가중치를 0으로 만든다". 두 알고리즘은 같은 OBS(Optimal Brain Surgeon) 프레임워크를 공유하며, 코드 구조도 거의 동일하다. OPT-175B를 4시간 만에 50% sparsity로 가지치기하면서 perplexity 손실이 무시할 만한 수준이라는 것이 주요 결과다. SparseGPT 논문과 llm-compressor의 src/llmcompressor/modifiers/pruning/sparsegpt/base.py를 분석한다.

공식 문서

논문 핵심 내용

SparseGPT의 기반은 OBS다. "한 가중치를 0으로 만들면 다른 가중치들을 얼마나 보정해야 하는가"를 2차 정보(헤시안)로 계산한다.

$$ \text{importance}i = \frac{w_i^2}{[H^{-1}]{ii}} $$

이 중요도가 낮은 가중치를 우선 제거하고, 남은 가중치를 OBS 공식으로 보정한다.

$$ w_j \leftarrow w_j - w_i \cdot \frac{[H^{-1}]{ij}}{[H^{-1}]{ii}} $$

GPTQ의 양자화 공식과 동일한 수식이다. 차이는 양자화가 "$w_i$를 반올림 값으로 교체"한다면, pruning은 "$w_i = 0$"으로 만든다는 것뿐이다. 두 연산을 통합해 볼 수 있고, 논문도 이를 명시적으로 언급한다.

N:M Sparsity 지원

SparseGPT의 중요한 특징은 2:4 sparsity 같은 N:M 구조를 지원한다는 것이다. 4개마다 2개를 0으로 만드는 제약 조건 하에서 OBS 최적해를 찾는다. 이는 NVIDIA A100 이상에서 하드웨어 가속을 받을 수 있어 실용성이 높다.

벤치마크 (OPT-175B, 논문 기준)

항목 수치
OPT-175B 가지치기 시간 ~4 GPU 시간
50% unstructured sparsity PPL 손실 무시할 수준
2:4 sparsity PPL 손실 약간 증가하나 여전히 수용 가능
지원 sparsity 50%, 60%, 70%, 2:4, 4:8
이전 방법 대비 OBS를 175B 모델에 적용한 최초 사례

핵심 구조/코드 분석

SparseGPTModifier 파라미터

class SparseGPTModifier(SGPTBaseModifier):
    """Implements SparseGPT from https://arxiv.org/abs/2301.00774"""

    # SparseGPT 전용 파라미터
    block_size: int = 128                      # GPTQ 와 동일한 의미
    dampening_frac: float = 0.01               # 헤시안 대각 정규화
    preserve_sparsity_mask: bool = False       # 기존 마스크 보존 모드

    # 베이스에서 상속: sparsity, mask_structure, targets, ignore, _hessians
파라미터 기본값 의미
block_size 128 OBS 블록 업데이트 단위 (GPTQ와 동일)
dampening_frac 0.01 헤시안 대각 정규화
preserve_sparsity_mask False 기존 0 위치를 유지하며 추가 가지치기
sparsity - 목표 희소성 (예: 0.5 = 50%)
mask_structure "unstructured" "2:4", "4:8"

on_start / on_sequential_epoch_end / on_finalize

SparseGPT의 라이프사이클 훅들은 GPTQ와 거의 동일하다.

def on_start(self, state: State, event: Event, **kwargs):
    # 대상 모듈 각각에 헤시안 누적 hook 등록 (GPTQ 와 동일)
    for name, module in match_named_modules(state.model, self.targets):
        self._module_names[module] = name
        self._hessians[module] = make_empty_hessian(module)
        self._num_samples[module] = 0

        def _hook(mod, args):
            x = args[0]
            accumulate_hessian(self._hessians[mod], self._num_samples, mod, x)

        module.register_forward_pre_hook(_hook)


def on_event(self, state: State, event: Event, **kwargs):
    if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
        subgraph = kwargs["subgraph"]
        for module_name in subgraph.consumed_names:
            module = state.model.get_submodule(module_name)
            if module in self._hessians:
                self._sparsify_module(module)
                del self._hessians[module]
                del self._num_samples[module]

차이는 _sparsify_module에 있다.

_sparsify_module: OBS 기반 pruning

def _sparsify_module(self, module):
    W = module.weight.data.clone().float()
    H = self._hessians[module].float()
    rows, cols = W.shape

    # 1) 헤시안 정규화
    damp = self.dampening_frac * torch.mean(torch.diag(H))
    H[torch.arange(cols), torch.arange(cols)] += damp

    # 2) Cholesky 역행렬 (GPTQ 와 동일)
    Hinv = torch.linalg.cholesky(torch.linalg.inv(H), upper=True)

    # 3) block_size 단위로 순회
    Mask = torch.zeros_like(W, dtype=torch.bool)
    for i1 in range(0, cols, self.block_size):
        i2 = min(i1 + self.block_size, cols)
        count = i2 - i1

        W1 = W[:, i1:i2].clone()
        Mask1 = torch.zeros_like(W1, dtype=torch.bool)
        Err1 = torch.zeros_like(W1)
        Hinv1 = Hinv[i1:i2, i1:i2]

        # 4) 블록 내 pruning 마스크 결정
        if self.mask_structure == "unstructured":
            # 중요도 기반 threshold 선정
            importance = W1.pow(2) / (torch.diag(Hinv1).unsqueeze(0) ** 2)
            threshold = torch.quantile(
                importance.flatten(),
                self.sparsity
            )
            Mask1 = importance > threshold

        elif self.mask_structure == "2:4":
            # 4 개 그룹 마다 상위 2 개만 유지
            Mask1 = _compute_2_4_mask(W1, Hinv1)

        # 5) 블록 내 OBS 보정 (pruning 된 위치는 0 설정)
        for j in range(count):
            w = W1[:, j]
            mask = Mask1[:, j]
            d = Hinv1[j, j]

            q = w * mask.float()    # 유지할 값만 남김
            err = (w - q) / d
            W1[:, j:] -= err.unsqueeze(1) * Hinv1[j, j:].unsqueeze(0)
            Err1[:, j] = err
            W1[:, j] = q

        # 6) 블록 외부에 lazy batch update (GPTQ 와 동일)
        W[:, i1:i2] = W1
        Mask[:, i1:i2] = Mask1
        W[:, i2:] -= Err1 @ Hinv[i1:i2, i2:]

    # 7) 최종 마스크 적용
    module.weight.data = W * Mask.float()

이 코드는 GPTQ의 quantize_weight구조적으로 거의 동일하다. 차이는 단 하나 — 양자화 함수 대신 pruning mask가 적용된다. 나머지(블록 순회, lazy batch update, Cholesky, dampening)는 모두 동일하다.

2:4 Mask 결정

def _compute_2_4_mask(W1, Hinv1):
    """
    For each row, 4 consecutive columns → keep top 2 by importance.
    """
    rows, cols = W1.shape
    mask = torch.zeros_like(W1, dtype=torch.bool)

    # 중요도 계산 (OBS 공식)
    importance = W1.pow(2) / (torch.diag(Hinv1).unsqueeze(0) ** 2)

    # 4 열씩 묶어서 top 2 선택
    for i in range(0, cols, 4):
        if i + 4 > cols:
            break
        group_importance = importance[:, i:i+4]
        _, topk = torch.topk(group_importance, k=2, dim=1)
        for r in range(rows):
            mask[r, i + topk[r]] = True

    return mask

각 행에 대해 4개 묶음마다 상위 2개를 유지한다. 이 제약 하에서 OBS 최적 해가 아닌 "그룹 내 local optimal"만 얻는다. 하지만 논문은 이 근사가 실용적으로 충분히 좋다는 것을 보인다.

왜 이 설계인가

1. GPTQ와의 코드 공유. SparseGPT는 GPTQ와 거의 같은 골격을 쓴다. block_size, dampening_frac, OBS 공식, lazy batch update 모두 동일. 차이는 "무엇으로 값을 대체하는가"(양자화 vs 0). 이 코드 중복은 의도적으로 남겨두어 각 알고리즘의 본질을 명확히 드러낸다.

2. N:M sparsity hardware-aware. 2:4 mask가 NVIDIA Tensor Core를 실제로 활용한다. 이 구조를 만드는 것이 SparseGPT의 실용적 가치다. unstructured pruning보다 정확도 손실이 약간 크지만 2배 가속을 얻는다.

3. preserve_sparsity_mask 옵션. 기존에 pruning된 체크포인트에 추가 sparsity를 주거나, fine-tuning으로 풀린 0을 복원할 때 사용한다. Sparsity-aware training 시나리오용.

4. GPTQ와 같은 sequential 의존. 헤시안을 누적해야 하므로 Sequential Pipeline이 필수. Basic 파이프라인으로는 동작 불가.

5. Block-wise OBS + lazy batch update. 전체를 한 번에 풀지 않고 블록 단위로 처리. 이 패턴이 GPTQ와 동일한 이유는 OBS의 계산량 문제 해결이 같기 때문이다.

마무리

SparseGPT는 llm-compressor에서 가장 정교한 pruning Modifier다. 헤시안 기반이라 메모리가 많이 들지만 정확도가 좋다. 다음 글은 계산량이 훨씬 적은 Wanda를 본다.

참고 자료

댓글

관련 포스트

llm-compressor 의 다른글