[llm-compressor] Magnitude Pruning: 크기 기반과 상수 희소성 Modifier
들어가며
Magnitude pruning은 가장 단순한 가지치기 방식이다. "가중치 절댓값이 작으면 덜 중요하므로 제거한다". 이 아이디어는 Han et al. (2015)의 Deep Compression 논문이 대중화했으며, LLM 시대에도 여전히 유효한 기본선(baseline)이다. llm-compressor의 src/llmcompressor/modifiers/pruning/magnitude/base.py와 src/llmcompressor/modifiers/pruning/constant/base.py 두 파일을 분석한다.
핵심 구조/코드 분석
MagnitudePruningModifier: Data-Free Pruning
class MagnitudePruningModifier(Modifier):
"""
Simple magnitude-based pruning: prune smallest |W| values.
Data-free — no calibration data required.
"""
sparsity: float = 0.5
mask_structure: str = "unstructured"
targets: list[str] = field(default_factory=lambda: ["Linear"])
ignore: list[str] = field(default_factory=list)
def on_initialize(self, state: State, **kwargs) -> bool:
for name, module in match_named_modules(
state.model, self.targets, self.ignore
):
if not hasattr(module, "weight"):
continue
W = module.weight.data
# 중요도 = 절댓값
importance = W.abs()
# 마스크 생성
mask = generate_mask(importance, self.sparsity, self.mask_structure)
# 가중치에 마스크 적용
module.weight.data *= mask.to(W.dtype)
return True
| 파라미터 | 기본값 | 의미 |
|---|---|---|
sparsity |
0.5 | 제거할 비율 (0.5 = 50%) |
mask_structure |
"unstructured" |
"2:4", "4:8" 등 지원 |
targets |
["Linear"] |
대상 레이어 타입 |
전체 코드가 30줄이 안 된다. 로직은 간단하다.
- 대상 모듈 순회
|W|를 중요도로 사용generate_mask로 threshold 또는 N:M 마스크 생성- 가중치에 마스크 곱
캘리브레이션 데이터가 필요 없다. on_initialize 한 번에 모든 일이 끝난다. Pipeline Registry는 이 Modifier를 보고 자동으로 data_free 파이프라인을 선택할 수 있다. 실제로는 QuantizationModifier가 아니므로 추론이 sequential을 고르지만, 실제 실행은 on_initialize에서 모두 끝나므로 sequential pipeline의 루프가 no-op이 된다.
generate_mask: 마스크 팩토리
src/llmcompressor/modifiers/pruning/utils/pytorch/mask_factory.py의 함수다.
def generate_mask(
importance: torch.Tensor, # 중요도 텐서 (가중치와 같은 shape)
sparsity: float, # 제거할 비율
mask_structure: str = "unstructured",
) -> torch.Tensor:
if mask_structure == "unstructured":
# 전체에서 하위 sparsity 비율 제거
k = int(importance.numel() * (1 - sparsity))
if k == 0:
return torch.zeros_like(importance, dtype=torch.bool)
threshold = torch.topk(importance.flatten(), k, largest=True).values[-1]
return importance >= threshold
elif mask_structure == "2:4":
return _apply_n_m_sparsity(importance, n=2, m=4)
elif mask_structure == "4:8":
return _apply_n_m_sparsity(importance, n=4, m=8)
elif mask_structure.startswith("block:"):
block_size = int(mask_structure.split(":")[1])
return _apply_block_sparsity(importance, block_size)
else:
raise ValueError(f"Unknown mask_structure: {mask_structure}")
def _apply_n_m_sparsity(importance, n: int, m: int):
"""Keep top-n values within every m-element group"""
rows, cols = importance.shape
mask = torch.zeros_like(importance, dtype=torch.bool)
for i in range(0, cols, m):
end = min(i + m, cols)
group = importance[:, i:end]
if end - i < n:
mask[:, i:end] = True # 그룹이 m보다 작으면 모두 유지
continue
_, topk = torch.topk(group, k=n, dim=1, largest=True)
for r in range(rows):
mask[r, i + topk[r]] = True
return mask
generate_mask가 통일된 인터페이스다. 모든 pruning Modifier(Wanda, SparseGPT, Magnitude)가 이 함수를 호출해 마스크를 만든다. 차이는 전달하는 importance 텐서뿐이다.
ConstantPruningModifier: 마스크 유지
class ConstantPruningModifier(Modifier):
"""
Maintain existing sparsity mask during fine-tuning.
Prevents previously pruned weights from being restored by gradient updates.
"""
targets: list[str] = field(default_factory=lambda: ["Linear"])
start: float = 0.0
end: float | None = None
def on_initialize(self, state: State, **kwargs) -> bool:
# 현재 가중치의 0 위치를 마스크로 저장
for name, module in match_named_modules(state.model, self.targets):
if not hasattr(module, "weight"):
continue
mask = (module.weight.data != 0).to(torch.float32)
module.register_buffer("_pruning_mask", mask, persistent=False)
return True
def on_update(self, state: State, event: Event, **kwargs):
"""Called after each optimizer step — re-apply mask to preserve zeros"""
if event.type_ != EventType.OPTIM_POST_STEP:
return
for name, module in match_named_modules(state.model, self.targets):
if hasattr(module, "_pruning_mask"):
module.weight.data *= module._pruning_mask
def on_finalize(self, state: State, **kwargs) -> bool:
# 버퍼 제거
for name, module in match_named_modules(state.model, self.targets):
if hasattr(module, "_pruning_mask"):
del module._pruning_mask
return True
ConstantPruningModifier는 pruning을 수행하지 않는다. 이미 pruning된 체크포인트를 fine-tuning할 때, optimizer가 0이었던 가중치를 다시 0이 아닌 값으로 만드는 것을 방지한다.
핵심은 on_update다. 옵티마이저 스텝 직후(OPTIM_POST_STEP 이벤트)에 마스크를 재적용한다. module.weight.data *= module._pruning_mask는 0이었던 위치를 다시 0으로 되돌린다. 이는 "sparsity를 유지하면서 나머지 가중치를 학습"하는 sparse fine-tuning을 가능하게 한다.
세 pruning Modifier의 포지셔닝
| Modifier | 캘리브레이션 | 정확도 | 속도 | 사용처 |
|---|---|---|---|---|
| MagnitudePruning | 불필요 | 낮음 | 매우 빠름 | 기본선, 테스트 |
| Wanda | 필요 (L2 norm) | 높음 | 빠름 | 대부분의 프로덕션 |
| SparseGPT | 필요 (헤시안) | 약간 더 높음 | 느림 | 최고 정확도 필요 시 |
| ConstantPruning | 불필요 | - | - | Sparse fine-tuning 유지 |
왜 이 설계인가
1. Magnitude는 data-free. 가장 단순한 케이스라 캘리브레이션 불필요. on_initialize 한 번에 모든 일이 끝난다. 코드 30줄, 디버깅 쉬움.
2. generate_mask 팩토리 재사용. 세 pruning Modifier가 모두 이 함수를 쓴다. N:M 구조 지원을 한 곳에서 관리하므로 추가 구조가 생기면 팩토리만 수정하면 된다.
3. ConstantPruningModifier의 별도 존재. "pruning"과 "pruning 상태 유지"는 다른 기능이다. Constant를 별도 Modifier로 두어 sparse fine-tuning 시 명시적으로 조합할 수 있다.
4. OPTIM_POST_STEP 활용. ConstantPruningModifier는 훈련 루프의 OPTIM_POST_STEP 이벤트를 사용한다. 이 이벤트는 옵티마이저가 가중치를 업데이트한 직후이므로, 마스크 재적용에 최적 시점이다.
5. persistent=False 버퍼. _pruning_mask를 non-persistent 버퍼로 등록해 state_dict 저장 시 포함되지 않게 한다. 체크포인트 파일에 이 중간 상태가 누출되지 않는다.
마무리
Magnitude Pruning은 pruning 세계의 "Hello World"다. 간단하지만 중요도 계산의 기본 패턴을 보여주며, generate_mask 팩토리의 일관된 인터페이스는 더 정교한 알고리즘의 기반이 된다. 이로써 Pruning 섹션이 끝났다. 다음 글부터는 Transform 섹션이다.
참고 자료
관련 포스트
llm-compressor 의 다른글
- 이전글 [llm-compressor] Wanda: 활성화 가중 노름 기반 가지치기
- 현재글 : [llm-compressor] Magnitude Pruning: 크기 기반과 상수 희소성 Modifier
- 다음글 [llm-compressor] Transform Overview: 가중치 회전/변환 기반 Modifier 계열
댓글