본문으로 건너뛰기

[llm-compressor] Wanda: 활성화 가중 노름 기반 가지치기

들어가며

Wanda(Pruning by Weights ANd Activations)는 2023년에 제안된 매우 단순한 가지치기 기법이다. SparseGPT가 헤시안 역행렬을 풀어 최적 가지치기를 찾는 반면, Wanda는 단 한 줄의 공식으로 중요도를 계산한다.

$$ \text{importance}{ij} = |W{ij}| \cdot |X_j|_2 $$

즉 각 가중치의 "중요도"는 가중치 절댓값과 해당 입력 채널의 활성화 L2 norm의 곱이다. 복잡한 역행렬 없이, forward pre-hook으로 활성화 norm만 수집하면 된다. 놀랍게도 이 단순한 공식이 SparseGPT와 비슷한 정확도를 달성한다. Wanda 논문과 llm-compressor의 src/llmcompressor/modifiers/pruning/wanda/base.py를 분석한다.

공식 문서

논문 핵심 내용

Wanda의 통찰은 "가중치의 절댓값만 보는 기존 magnitude pruning은 틀렸다"는 것이다. 가중치가 커도 입력 활성화가 0에 가까우면 출력에 기여가 없다. 반대로 작은 가중치라도 입력이 강하게 연결되면 중요하다. 이 두 요인을 곱한 것이 출력 영향도의 좋은 근사다.

$$ |W_{ij} \cdot x_j| \approx \text{output contribution of weight } W_{ij} $$

여러 샘플에 대해 평균을 내면 $|X_j|_2$로 귀결된다. 이 중요도로 각 행(출력 채널) 내에서 상위 k%를 유지한다.

Per-output-row pruning이 Wanda의 또 다른 특징이다. 전체에서 상위 k%가 아닌 "각 출력 뉴런마다 독립적으로" 상위 k%를 고른다. 이는 모든 출력 뉴런이 균등한 sparsity를 가지도록 보장해 불균일을 방지한다.

벤치마크 (논문 기준, LLaMA-65B)

방법 50% sparsity PPL
Dense (original) 3.78
Magnitude pruning 26.24
SparseGPT 4.28
Wanda 4.22

Wanda가 SparseGPT보다 살짝 좋은 결과를 보이면서도, 계산량은 비교할 수 없이 적다. SparseGPT가 수 시간 걸리는 작업을 Wanda는 수 분 만에 끝낸다.

핵심 구조/코드 분석

WandaPruningModifier 파라미터

class WandaPruningModifier(SGPTBaseModifier):
    """Implements Wanda from https://arxiv.org/abs/2306.11695"""

    # Wanda 는 block_size, dampening_frac 같은 OBS 파라미터가 없음
    # 베이스에서 상속: sparsity, mask_structure, targets, ignore

    # 활성화 L2 norm 수집용 버퍼
    _activation_norms: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(
        default_factory=dict
    )
    _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)
파라미터 기본값 의미
sparsity - 목표 희소성 (0.5 = 50%)
mask_structure "unstructured" 또는 "2:4", "4:8"
targets ["Linear"] 대상 레이어
ignore [] 제외

Wanda 전용 파라미터는 없다. OBS 기반이 아니므로 block_size, dampening_frac 같은 튜닝이 불필요하다.

on_start: L2 norm 누적 훅 등록

def on_start(self, state: State, event: Event, **kwargs):
    QuantizationMixin.start_calibration(self, state.model)

    for name, module in match_named_modules(state.model, self.targets):
        if not isinstance(module, torch.nn.Linear):
            continue

        self._module_names[module] = name
        # 입력 채널별 누적 L2 norm^2
        self._activation_norms[module] = torch.zeros(
            module.in_features, dtype=torch.float32
        )
        self._num_samples[module] = 0

        def _hook(mod, args):
            x = args[0] if isinstance(args, tuple) else args
            if isinstance(x, tuple):
                x = x[0]

            x_f = x.detach().float()
            # 각 입력 채널의 norm² 누적
            # x.shape = (..., in_features) → 마지막 차원 유지하고 나머지 reduce
            norm_sq = x_f.pow(2).sum(dim=tuple(range(x_f.dim() - 1)))
            n_tokens = math.prod(x_f.shape[:-1])

            self._activation_norms[mod] = self._activation_norms[mod].to(norm_sq.device)
            self._activation_norms[mod].add_(norm_sq)
            self._num_samples[mod] += n_tokens

        module.register_forward_pre_hook(_hook)

이 훅은 각 Linear에 대해 "입력 채널별 제곱합"을 누적한다. 모든 배치를 순회한 뒤 sqrt(sum / num_samples)를 하면 채널별 L2 norm이 된다.

FP32로 누적하는 것은 iMatrix Observer와 같은 이유다. BF16/FP16은 제곱합이 빠르게 오버플로우될 수 있다.

on_event로 레이어별 pruning

def on_event(self, state: State, event: Event, **kwargs):
    if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
        subgraph = kwargs["subgraph"]
        for module_name in subgraph.consumed_names:
            module = state.model.get_submodule(module_name)
            if module in self._activation_norms:
                self._prune_module(module)


def _prune_module(self, module):
    W = module.weight.data      # shape (out_features, in_features)
    norm_sq = self._activation_norms[module]
    num_samples = self._num_samples[module]

    # 평균 L2 norm
    activation_norm = torch.sqrt(norm_sq / num_samples)   # shape (in_features,)

    # Wanda importance: |W| * ||X||_2 (broadcast)
    importance = W.abs() * activation_norm.unsqueeze(0)
    # importance.shape == W.shape

    # Per-row pruning: 각 출력 뉴런마다 독립적으로 top-(1-sparsity)% 유지
    if self.mask_structure == "unstructured":
        rows, cols = W.shape
        k_keep = int(cols * (1 - self.sparsity))

        # 각 행에서 상위 k_keep 개의 인덱스
        _, top_idx = torch.topk(importance, k_keep, dim=1, largest=True)

        mask = torch.zeros_like(W, dtype=torch.bool)
        mask.scatter_(1, top_idx, True)

    elif self.mask_structure == "2:4":
        mask = _compute_2_4_wanda_mask(importance)

    # Apply mask
    module.weight.data = W * mask.float()

    # Clean up
    del self._activation_norms[module]
    del self._num_samples[module]

핵심은 importance = W.abs() * activation_norm.unsqueeze(0) 한 줄이다. 이 broadcast 곱으로 (out_features, in_features) 크기의 중요도 텐서가 만들어진다. activation_norm(in_features,)이므로 각 열에 그 열의 norm이 곱해진다.

torch.topk(importance, k_keep, dim=1)은 각 행에서 독립적으로 상위 k_keep 개를 고른다. 이것이 per-output-row pruning이다. scatter_로 그 위치들만 True인 마스크를 생성한다.

2:4 마스크는 SparseGPT와 유사

def _compute_2_4_wanda_mask(importance):
    """4개 묶음마다 상위 2개 유지"""
    rows, cols = importance.shape
    mask = torch.zeros_like(importance, dtype=torch.bool)

    for i in range(0, cols, 4):
        if i + 4 > cols:
            break
        group = importance[:, i:i+4]
        _, topk_idx = torch.topk(group, k=2, dim=1, largest=True)
        for r in range(rows):
            mask[r, i + topk_idx[r]] = True

    return mask

구조는 SparseGPT의 _compute_2_4_mask와 동일하지만, 중요도 공식이 다르다. SparseGPT는 OBS 기반, Wanda는 단순 곱.

왜 이 설계인가

1. 단순함의 미덕. Wanda의 전체 pruning 로직이 30줄 안에 들어간다. 역행렬도, 블록 순회도, lazy batch update도 없다. 구현 실수 여지가 거의 없다.

2. Forward pre-hook만으로 충분. 활성화 norm을 누적하는 것이 유일한 캘리브레이션 작업이다. 메모리 부담이 적고 속도가 빠르다. 70B 모델도 수 분 안에 끝난다.

3. Per-row pruning. 전체 레이어에서 상위 k%가 아니라 각 출력 행에서 독립적으로 상위 k%를 고른다. 이는 "일부 뉴런이 너무 많이 pruning되는" 불균형을 방지한다.

4. SparseGPT와 비교해 정확도 우위. LLaMA-65B에서 Wanda가 살짝 좋았다(4.22 vs 4.28). 이는 "더 복잡한 알고리즘이 더 좋다"는 가정을 반박한다. 단순함이 일관성 있게 작동하는 경우가 많다.

5. Sequential Pipeline 의존. Wanda도 활성화 통계가 필요하므로 data_free로는 동작 불가. 단 GPTQ/SparseGPT와 달리 비용이 매우 낮으므로 sequential pipeline의 오버헤드가 거의 느껴지지 않는다.

마무리

Wanda는 "단순한 것이 좋다"의 교본이다. 한 줄의 수식이 복잡한 OBS 구현을 대체한다. 다음 글은 가장 단순한 Magnitude Pruning을 본다.

참고 자료

댓글

관련 포스트

llm-compressor 의 다른글