[llm-compressor] Wanda: 활성화 가중 노름 기반 가지치기
들어가며
Wanda(Pruning by Weights ANd Activations)는 2023년에 제안된 매우 단순한 가지치기 기법이다. SparseGPT가 헤시안 역행렬을 풀어 최적 가지치기를 찾는 반면, Wanda는 단 한 줄의 공식으로 중요도를 계산한다.
$$ \text{importance}{ij} = |W{ij}| \cdot |X_j|_2 $$
즉 각 가중치의 "중요도"는 가중치 절댓값과 해당 입력 채널의 활성화 L2 norm의 곱이다. 복잡한 역행렬 없이, forward pre-hook으로 활성화 norm만 수집하면 된다. 놀랍게도 이 단순한 공식이 SparseGPT와 비슷한 정확도를 달성한다. Wanda 논문과 llm-compressor의 src/llmcompressor/modifiers/pruning/wanda/base.py를 분석한다.
공식 문서
- 논문: A Simple and Effective Pruning Approach for Large Language Models
- 예제: examples/sparse_2of4_quantization_fp8/
논문 핵심 내용
Wanda의 통찰은 "가중치의 절댓값만 보는 기존 magnitude pruning은 틀렸다"는 것이다. 가중치가 커도 입력 활성화가 0에 가까우면 출력에 기여가 없다. 반대로 작은 가중치라도 입력이 강하게 연결되면 중요하다. 이 두 요인을 곱한 것이 출력 영향도의 좋은 근사다.
$$ |W_{ij} \cdot x_j| \approx \text{output contribution of weight } W_{ij} $$
여러 샘플에 대해 평균을 내면 $|X_j|_2$로 귀결된다. 이 중요도로 각 행(출력 채널) 내에서 상위 k%를 유지한다.
Per-output-row pruning이 Wanda의 또 다른 특징이다. 전체에서 상위 k%가 아닌 "각 출력 뉴런마다 독립적으로" 상위 k%를 고른다. 이는 모든 출력 뉴런이 균등한 sparsity를 가지도록 보장해 불균일을 방지한다.
벤치마크 (논문 기준, LLaMA-65B)
| 방법 | 50% sparsity PPL |
|---|---|
| Dense (original) | 3.78 |
| Magnitude pruning | 26.24 |
| SparseGPT | 4.28 |
| Wanda | 4.22 |
Wanda가 SparseGPT보다 살짝 좋은 결과를 보이면서도, 계산량은 비교할 수 없이 적다. SparseGPT가 수 시간 걸리는 작업을 Wanda는 수 분 만에 끝낸다.
핵심 구조/코드 분석
WandaPruningModifier 파라미터
class WandaPruningModifier(SGPTBaseModifier):
"""Implements Wanda from https://arxiv.org/abs/2306.11695"""
# Wanda 는 block_size, dampening_frac 같은 OBS 파라미터가 없음
# 베이스에서 상속: sparsity, mask_structure, targets, ignore
# 활성화 L2 norm 수집용 버퍼
_activation_norms: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(
default_factory=dict
)
_num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)
| 파라미터 | 기본값 | 의미 |
|---|---|---|
sparsity |
- | 목표 희소성 (0.5 = 50%) |
mask_structure |
"unstructured" |
또는 "2:4", "4:8" |
targets |
["Linear"] |
대상 레이어 |
ignore |
[] |
제외 |
Wanda 전용 파라미터는 없다. OBS 기반이 아니므로 block_size, dampening_frac 같은 튜닝이 불필요하다.
on_start: L2 norm 누적 훅 등록
def on_start(self, state: State, event: Event, **kwargs):
QuantizationMixin.start_calibration(self, state.model)
for name, module in match_named_modules(state.model, self.targets):
if not isinstance(module, torch.nn.Linear):
continue
self._module_names[module] = name
# 입력 채널별 누적 L2 norm^2
self._activation_norms[module] = torch.zeros(
module.in_features, dtype=torch.float32
)
self._num_samples[module] = 0
def _hook(mod, args):
x = args[0] if isinstance(args, tuple) else args
if isinstance(x, tuple):
x = x[0]
x_f = x.detach().float()
# 각 입력 채널의 norm² 누적
# x.shape = (..., in_features) → 마지막 차원 유지하고 나머지 reduce
norm_sq = x_f.pow(2).sum(dim=tuple(range(x_f.dim() - 1)))
n_tokens = math.prod(x_f.shape[:-1])
self._activation_norms[mod] = self._activation_norms[mod].to(norm_sq.device)
self._activation_norms[mod].add_(norm_sq)
self._num_samples[mod] += n_tokens
module.register_forward_pre_hook(_hook)
이 훅은 각 Linear에 대해 "입력 채널별 제곱합"을 누적한다. 모든 배치를 순회한 뒤 sqrt(sum / num_samples)를 하면 채널별 L2 norm이 된다.
FP32로 누적하는 것은 iMatrix Observer와 같은 이유다. BF16/FP16은 제곱합이 빠르게 오버플로우될 수 있다.
on_event로 레이어별 pruning
def on_event(self, state: State, event: Event, **kwargs):
if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
subgraph = kwargs["subgraph"]
for module_name in subgraph.consumed_names:
module = state.model.get_submodule(module_name)
if module in self._activation_norms:
self._prune_module(module)
def _prune_module(self, module):
W = module.weight.data # shape (out_features, in_features)
norm_sq = self._activation_norms[module]
num_samples = self._num_samples[module]
# 평균 L2 norm
activation_norm = torch.sqrt(norm_sq / num_samples) # shape (in_features,)
# Wanda importance: |W| * ||X||_2 (broadcast)
importance = W.abs() * activation_norm.unsqueeze(0)
# importance.shape == W.shape
# Per-row pruning: 각 출력 뉴런마다 독립적으로 top-(1-sparsity)% 유지
if self.mask_structure == "unstructured":
rows, cols = W.shape
k_keep = int(cols * (1 - self.sparsity))
# 각 행에서 상위 k_keep 개의 인덱스
_, top_idx = torch.topk(importance, k_keep, dim=1, largest=True)
mask = torch.zeros_like(W, dtype=torch.bool)
mask.scatter_(1, top_idx, True)
elif self.mask_structure == "2:4":
mask = _compute_2_4_wanda_mask(importance)
# Apply mask
module.weight.data = W * mask.float()
# Clean up
del self._activation_norms[module]
del self._num_samples[module]
핵심은 importance = W.abs() * activation_norm.unsqueeze(0) 한 줄이다. 이 broadcast 곱으로 (out_features, in_features) 크기의 중요도 텐서가 만들어진다. activation_norm은 (in_features,)이므로 각 열에 그 열의 norm이 곱해진다.
torch.topk(importance, k_keep, dim=1)은 각 행에서 독립적으로 상위 k_keep 개를 고른다. 이것이 per-output-row pruning이다. scatter_로 그 위치들만 True인 마스크를 생성한다.
2:4 마스크는 SparseGPT와 유사
def _compute_2_4_wanda_mask(importance):
"""4개 묶음마다 상위 2개 유지"""
rows, cols = importance.shape
mask = torch.zeros_like(importance, dtype=torch.bool)
for i in range(0, cols, 4):
if i + 4 > cols:
break
group = importance[:, i:i+4]
_, topk_idx = torch.topk(group, k=2, dim=1, largest=True)
for r in range(rows):
mask[r, i + topk_idx[r]] = True
return mask
구조는 SparseGPT의 _compute_2_4_mask와 동일하지만, 중요도 공식이 다르다. SparseGPT는 OBS 기반, Wanda는 단순 곱.
왜 이 설계인가
1. 단순함의 미덕. Wanda의 전체 pruning 로직이 30줄 안에 들어간다. 역행렬도, 블록 순회도, lazy batch update도 없다. 구현 실수 여지가 거의 없다.
2. Forward pre-hook만으로 충분. 활성화 norm을 누적하는 것이 유일한 캘리브레이션 작업이다. 메모리 부담이 적고 속도가 빠르다. 70B 모델도 수 분 안에 끝난다.
3. Per-row pruning. 전체 레이어에서 상위 k%가 아니라 각 출력 행에서 독립적으로 상위 k%를 고른다. 이는 "일부 뉴런이 너무 많이 pruning되는" 불균형을 방지한다.
4. SparseGPT와 비교해 정확도 우위. LLaMA-65B에서 Wanda가 살짝 좋았다(4.22 vs 4.28). 이는 "더 복잡한 알고리즘이 더 좋다"는 가정을 반박한다. 단순함이 일관성 있게 작동하는 경우가 많다.
5. Sequential Pipeline 의존. Wanda도 활성화 통계가 필요하므로 data_free로는 동작 불가. 단 GPTQ/SparseGPT와 달리 비용이 매우 낮으므로 sequential pipeline의 오버헤드가 거의 느껴지지 않는다.
마무리
Wanda는 "단순한 것이 좋다"의 교본이다. 한 줄의 수식이 복잡한 OBS 구현을 대체한다. 다음 글은 가장 단순한 Magnitude Pruning을 본다.
참고 자료
관련 포스트
llm-compressor 의 다른글
- 이전글 [llm-compressor] SparseGPT: 원샷 LLM 가지치기 구현
- 현재글 : [llm-compressor] Wanda: 활성화 가중 노름 기반 가지치기
- 다음글 [llm-compressor] Magnitude Pruning: 크기 기반과 상수 희소성 Modifier
댓글