[llm-compressor] SparseGPT: 원샷 LLM 가지치기 구현
들어가며
SparseGPT는 GPTQ와 같은 연구진이 2023년에 발표한 "양자화의 자매 알고리즘"이다. GPTQ가 "가중치를 정수로 반올림"한다면 SparseGPT는 "일부 가중치를 0으로 만든다". 두 알고리즘은 같은 OBS(Optimal Brain Surgeon) 프레임워크를 공유하며, 코드 구조도 거의 동일하다. OPT-175B를 4시간 만에 50% sparsity로 가지치기하면서 perplexity 손실이 무시할 만한 수준이라는 것이 주요 결과다. SparseGPT 논문과 llm-compressor의 src/llmcompressor/modifiers/pruning/sparsegpt/base.py를 분석한다.
공식 문서
- 논문: SparseGPT: Massive Language Models Can be Accurately Pruned in One-Shot
- 예제: examples/sparse_2of4_quantization_fp8/
논문 핵심 내용
SparseGPT의 기반은 OBS다. "한 가중치를 0으로 만들면 다른 가중치들을 얼마나 보정해야 하는가"를 2차 정보(헤시안)로 계산한다.
$$ \text{importance}i = \frac{w_i^2}{[H^{-1}]{ii}} $$
이 중요도가 낮은 가중치를 우선 제거하고, 남은 가중치를 OBS 공식으로 보정한다.
$$ w_j \leftarrow w_j - w_i \cdot \frac{[H^{-1}]{ij}}{[H^{-1}]{ii}} $$
GPTQ의 양자화 공식과 동일한 수식이다. 차이는 양자화가 "$w_i$를 반올림 값으로 교체"한다면, pruning은 "$w_i = 0$"으로 만든다는 것뿐이다. 두 연산을 통합해 볼 수 있고, 논문도 이를 명시적으로 언급한다.
N:M Sparsity 지원
SparseGPT의 중요한 특징은 2:4 sparsity 같은 N:M 구조를 지원한다는 것이다. 4개마다 2개를 0으로 만드는 제약 조건 하에서 OBS 최적해를 찾는다. 이는 NVIDIA A100 이상에서 하드웨어 가속을 받을 수 있어 실용성이 높다.
벤치마크 (OPT-175B, 논문 기준)
| 항목 | 수치 |
|---|---|
| OPT-175B 가지치기 시간 | ~4 GPU 시간 |
| 50% unstructured sparsity PPL 손실 | 무시할 수준 |
| 2:4 sparsity PPL 손실 | 약간 증가하나 여전히 수용 가능 |
| 지원 sparsity | 50%, 60%, 70%, 2:4, 4:8 |
| 이전 방법 대비 | OBS를 175B 모델에 적용한 최초 사례 |
핵심 구조/코드 분석
SparseGPTModifier 파라미터
class SparseGPTModifier(SGPTBaseModifier):
"""Implements SparseGPT from https://arxiv.org/abs/2301.00774"""
# SparseGPT 전용 파라미터
block_size: int = 128 # GPTQ 와 동일한 의미
dampening_frac: float = 0.01 # 헤시안 대각 정규화
preserve_sparsity_mask: bool = False # 기존 마스크 보존 모드
# 베이스에서 상속: sparsity, mask_structure, targets, ignore, _hessians
| 파라미터 | 기본값 | 의미 |
|---|---|---|
block_size |
128 | OBS 블록 업데이트 단위 (GPTQ와 동일) |
dampening_frac |
0.01 | 헤시안 대각 정규화 |
preserve_sparsity_mask |
False | 기존 0 위치를 유지하며 추가 가지치기 |
sparsity |
- | 목표 희소성 (예: 0.5 = 50%) |
mask_structure |
"unstructured" |
"2:4", "4:8" 등 |
on_start / on_sequential_epoch_end / on_finalize
SparseGPT의 라이프사이클 훅들은 GPTQ와 거의 동일하다.
def on_start(self, state: State, event: Event, **kwargs):
# 대상 모듈 각각에 헤시안 누적 hook 등록 (GPTQ 와 동일)
for name, module in match_named_modules(state.model, self.targets):
self._module_names[module] = name
self._hessians[module] = make_empty_hessian(module)
self._num_samples[module] = 0
def _hook(mod, args):
x = args[0]
accumulate_hessian(self._hessians[mod], self._num_samples, mod, x)
module.register_forward_pre_hook(_hook)
def on_event(self, state: State, event: Event, **kwargs):
if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
subgraph = kwargs["subgraph"]
for module_name in subgraph.consumed_names:
module = state.model.get_submodule(module_name)
if module in self._hessians:
self._sparsify_module(module)
del self._hessians[module]
del self._num_samples[module]
차이는 _sparsify_module에 있다.
_sparsify_module: OBS 기반 pruning
def _sparsify_module(self, module):
W = module.weight.data.clone().float()
H = self._hessians[module].float()
rows, cols = W.shape
# 1) 헤시안 정규화
damp = self.dampening_frac * torch.mean(torch.diag(H))
H[torch.arange(cols), torch.arange(cols)] += damp
# 2) Cholesky 역행렬 (GPTQ 와 동일)
Hinv = torch.linalg.cholesky(torch.linalg.inv(H), upper=True)
# 3) block_size 단위로 순회
Mask = torch.zeros_like(W, dtype=torch.bool)
for i1 in range(0, cols, self.block_size):
i2 = min(i1 + self.block_size, cols)
count = i2 - i1
W1 = W[:, i1:i2].clone()
Mask1 = torch.zeros_like(W1, dtype=torch.bool)
Err1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]
# 4) 블록 내 pruning 마스크 결정
if self.mask_structure == "unstructured":
# 중요도 기반 threshold 선정
importance = W1.pow(2) / (torch.diag(Hinv1).unsqueeze(0) ** 2)
threshold = torch.quantile(
importance.flatten(),
self.sparsity
)
Mask1 = importance > threshold
elif self.mask_structure == "2:4":
# 4 개 그룹 마다 상위 2 개만 유지
Mask1 = _compute_2_4_mask(W1, Hinv1)
# 5) 블록 내 OBS 보정 (pruning 된 위치는 0 설정)
for j in range(count):
w = W1[:, j]
mask = Mask1[:, j]
d = Hinv1[j, j]
q = w * mask.float() # 유지할 값만 남김
err = (w - q) / d
W1[:, j:] -= err.unsqueeze(1) * Hinv1[j, j:].unsqueeze(0)
Err1[:, j] = err
W1[:, j] = q
# 6) 블록 외부에 lazy batch update (GPTQ 와 동일)
W[:, i1:i2] = W1
Mask[:, i1:i2] = Mask1
W[:, i2:] -= Err1 @ Hinv[i1:i2, i2:]
# 7) 최종 마스크 적용
module.weight.data = W * Mask.float()
이 코드는 GPTQ의 quantize_weight와 구조적으로 거의 동일하다. 차이는 단 하나 — 양자화 함수 대신 pruning mask가 적용된다. 나머지(블록 순회, lazy batch update, Cholesky, dampening)는 모두 동일하다.
2:4 Mask 결정
def _compute_2_4_mask(W1, Hinv1):
"""
For each row, 4 consecutive columns → keep top 2 by importance.
"""
rows, cols = W1.shape
mask = torch.zeros_like(W1, dtype=torch.bool)
# 중요도 계산 (OBS 공식)
importance = W1.pow(2) / (torch.diag(Hinv1).unsqueeze(0) ** 2)
# 4 열씩 묶어서 top 2 선택
for i in range(0, cols, 4):
if i + 4 > cols:
break
group_importance = importance[:, i:i+4]
_, topk = torch.topk(group_importance, k=2, dim=1)
for r in range(rows):
mask[r, i + topk[r]] = True
return mask
각 행에 대해 4개 묶음마다 상위 2개를 유지한다. 이 제약 하에서 OBS 최적 해가 아닌 "그룹 내 local optimal"만 얻는다. 하지만 논문은 이 근사가 실용적으로 충분히 좋다는 것을 보인다.
왜 이 설계인가
1. GPTQ와의 코드 공유. SparseGPT는 GPTQ와 거의 같은 골격을 쓴다. block_size, dampening_frac, OBS 공식, lazy batch update 모두 동일. 차이는 "무엇으로 값을 대체하는가"(양자화 vs 0). 이 코드 중복은 의도적으로 남겨두어 각 알고리즘의 본질을 명확히 드러낸다.
2. N:M sparsity hardware-aware. 2:4 mask가 NVIDIA Tensor Core를 실제로 활용한다. 이 구조를 만드는 것이 SparseGPT의 실용적 가치다. unstructured pruning보다 정확도 손실이 약간 크지만 2배 가속을 얻는다.
3. preserve_sparsity_mask 옵션. 기존에 pruning된 체크포인트에 추가 sparsity를 주거나, fine-tuning으로 풀린 0을 복원할 때 사용한다. Sparsity-aware training 시나리오용.
4. GPTQ와 같은 sequential 의존. 헤시안을 누적해야 하므로 Sequential Pipeline이 필수. Basic 파이프라인으로는 동작 불가.
5. Block-wise OBS + lazy batch update. 전체를 한 번에 풀지 않고 블록 단위로 처리. 이 패턴이 GPTQ와 동일한 이유는 OBS의 계산량 문제 해결이 같기 때문이다.
마무리
SparseGPT는 llm-compressor에서 가장 정교한 pruning Modifier다. 헤시안 기반이라 메모리가 많이 들지만 정확도가 좋다. 다음 글은 계산량이 훨씬 적은 Wanda를 본다.
참고 자료
관련 포스트
llm-compressor 의 다른글
- 이전글 [llm-compressor] Pruning Overview: OBCQ 계열 Modifier 구조
- 현재글 : [llm-compressor] SparseGPT: 원샷 LLM 가지치기 구현
- 다음글 [llm-compressor] Wanda: 활성화 가중 노름 기반 가지치기
댓글