본문으로 건너뛰기

[llm-compressor] Magnitude Pruning: 크기 기반과 상수 희소성 Modifier

들어가며

Magnitude pruning은 가장 단순한 가지치기 방식이다. "가중치 절댓값이 작으면 덜 중요하므로 제거한다". 이 아이디어는 Han et al. (2015)의 Deep Compression 논문이 대중화했으며, LLM 시대에도 여전히 유효한 기본선(baseline)이다. llm-compressor의 src/llmcompressor/modifiers/pruning/magnitude/base.pysrc/llmcompressor/modifiers/pruning/constant/base.py 두 파일을 분석한다.

핵심 구조/코드 분석

MagnitudePruningModifier: Data-Free Pruning

class MagnitudePruningModifier(Modifier):
    """
    Simple magnitude-based pruning: prune smallest |W| values.
    Data-free — no calibration data required.
    """
    sparsity: float = 0.5
    mask_structure: str = "unstructured"
    targets: list[str] = field(default_factory=lambda: ["Linear"])
    ignore: list[str] = field(default_factory=list)

    def on_initialize(self, state: State, **kwargs) -> bool:
        for name, module in match_named_modules(
            state.model, self.targets, self.ignore
        ):
            if not hasattr(module, "weight"):
                continue

            W = module.weight.data
            # 중요도 = 절댓값
            importance = W.abs()

            # 마스크 생성
            mask = generate_mask(importance, self.sparsity, self.mask_structure)

            # 가중치에 마스크 적용
            module.weight.data *= mask.to(W.dtype)

        return True
파라미터 기본값 의미
sparsity 0.5 제거할 비율 (0.5 = 50%)
mask_structure "unstructured" "2:4", "4:8" 등 지원
targets ["Linear"] 대상 레이어 타입

전체 코드가 30줄이 안 된다. 로직은 간단하다.

  1. 대상 모듈 순회
  2. |W|를 중요도로 사용
  3. generate_mask로 threshold 또는 N:M 마스크 생성
  4. 가중치에 마스크 곱

캘리브레이션 데이터가 필요 없다. on_initialize 한 번에 모든 일이 끝난다. Pipeline Registry는 이 Modifier를 보고 자동으로 data_free 파이프라인을 선택할 수 있다. 실제로는 QuantizationModifier가 아니므로 추론이 sequential을 고르지만, 실제 실행은 on_initialize에서 모두 끝나므로 sequential pipeline의 루프가 no-op이 된다.

generate_mask: 마스크 팩토리

src/llmcompressor/modifiers/pruning/utils/pytorch/mask_factory.py의 함수다.

def generate_mask(
    importance: torch.Tensor,        # 중요도 텐서 (가중치와 같은 shape)
    sparsity: float,                 # 제거할 비율
    mask_structure: str = "unstructured",
) -> torch.Tensor:
    if mask_structure == "unstructured":
        # 전체에서 하위 sparsity 비율 제거
        k = int(importance.numel() * (1 - sparsity))
        if k == 0:
            return torch.zeros_like(importance, dtype=torch.bool)
        threshold = torch.topk(importance.flatten(), k, largest=True).values[-1]
        return importance >= threshold

    elif mask_structure == "2:4":
        return _apply_n_m_sparsity(importance, n=2, m=4)

    elif mask_structure == "4:8":
        return _apply_n_m_sparsity(importance, n=4, m=8)

    elif mask_structure.startswith("block:"):
        block_size = int(mask_structure.split(":")[1])
        return _apply_block_sparsity(importance, block_size)

    else:
        raise ValueError(f"Unknown mask_structure: {mask_structure}")


def _apply_n_m_sparsity(importance, n: int, m: int):
    """Keep top-n values within every m-element group"""
    rows, cols = importance.shape
    mask = torch.zeros_like(importance, dtype=torch.bool)

    for i in range(0, cols, m):
        end = min(i + m, cols)
        group = importance[:, i:end]
        if end - i < n:
            mask[:, i:end] = True   # 그룹이 m보다 작으면 모두 유지
            continue
        _, topk = torch.topk(group, k=n, dim=1, largest=True)
        for r in range(rows):
            mask[r, i + topk[r]] = True

    return mask

generate_mask가 통일된 인터페이스다. 모든 pruning Modifier(Wanda, SparseGPT, Magnitude)가 이 함수를 호출해 마스크를 만든다. 차이는 전달하는 importance 텐서뿐이다.

ConstantPruningModifier: 마스크 유지

class ConstantPruningModifier(Modifier):
    """
    Maintain existing sparsity mask during fine-tuning.
    Prevents previously pruned weights from being restored by gradient updates.
    """
    targets: list[str] = field(default_factory=lambda: ["Linear"])
    start: float = 0.0
    end: float | None = None

    def on_initialize(self, state: State, **kwargs) -> bool:
        # 현재 가중치의 0 위치를 마스크로 저장
        for name, module in match_named_modules(state.model, self.targets):
            if not hasattr(module, "weight"):
                continue
            mask = (module.weight.data != 0).to(torch.float32)
            module.register_buffer("_pruning_mask", mask, persistent=False)

        return True

    def on_update(self, state: State, event: Event, **kwargs):
        """Called after each optimizer step — re-apply mask to preserve zeros"""
        if event.type_ != EventType.OPTIM_POST_STEP:
            return

        for name, module in match_named_modules(state.model, self.targets):
            if hasattr(module, "_pruning_mask"):
                module.weight.data *= module._pruning_mask

    def on_finalize(self, state: State, **kwargs) -> bool:
        # 버퍼 제거
        for name, module in match_named_modules(state.model, self.targets):
            if hasattr(module, "_pruning_mask"):
                del module._pruning_mask
        return True

ConstantPruningModifierpruning을 수행하지 않는다. 이미 pruning된 체크포인트를 fine-tuning할 때, optimizer가 0이었던 가중치를 다시 0이 아닌 값으로 만드는 것을 방지한다.

핵심은 on_update다. 옵티마이저 스텝 직후(OPTIM_POST_STEP 이벤트)에 마스크를 재적용한다. module.weight.data *= module._pruning_mask는 0이었던 위치를 다시 0으로 되돌린다. 이는 "sparsity를 유지하면서 나머지 가중치를 학습"하는 sparse fine-tuning을 가능하게 한다.

세 pruning Modifier의 포지셔닝

Modifier 캘리브레이션 정확도 속도 사용처
MagnitudePruning 불필요 낮음 매우 빠름 기본선, 테스트
Wanda 필요 (L2 norm) 높음 빠름 대부분의 프로덕션
SparseGPT 필요 (헤시안) 약간 더 높음 느림 최고 정확도 필요 시
ConstantPruning 불필요 - - Sparse fine-tuning 유지

왜 이 설계인가

1. Magnitude는 data-free. 가장 단순한 케이스라 캘리브레이션 불필요. on_initialize 한 번에 모든 일이 끝난다. 코드 30줄, 디버깅 쉬움.

2. generate_mask 팩토리 재사용. 세 pruning Modifier가 모두 이 함수를 쓴다. N:M 구조 지원을 한 곳에서 관리하므로 추가 구조가 생기면 팩토리만 수정하면 된다.

3. ConstantPruningModifier의 별도 존재. "pruning"과 "pruning 상태 유지"는 다른 기능이다. Constant를 별도 Modifier로 두어 sparse fine-tuning 시 명시적으로 조합할 수 있다.

4. OPTIM_POST_STEP 활용. ConstantPruningModifier는 훈련 루프의 OPTIM_POST_STEP 이벤트를 사용한다. 이 이벤트는 옵티마이저가 가중치를 업데이트한 직후이므로, 마스크 재적용에 최적 시점이다.

5. persistent=False 버퍼. _pruning_mask를 non-persistent 버퍼로 등록해 state_dict 저장 시 포함되지 않게 한다. 체크포인트 파일에 이 중간 상태가 누출되지 않는다.

마무리

Magnitude Pruning은 pruning 세계의 "Hello World"다. 간단하지만 중요도 계산의 기본 패턴을 보여주며, generate_mask 팩토리의 일관된 인터페이스는 더 정교한 알고리즘의 기반이 된다. 이로써 Pruning 섹션이 끝났다. 다음 글부터는 Transform 섹션이다.

참고 자료

댓글

관련 포스트

llm-compressor 의 다른글