[llm-compressor] PyTorch Utils: 희소화 통계와 모듈 헬퍼
들어가며
llm-compressor는 여러 곳에서 "모델 전체를 순회하며 무언가를 집계하거나 변경"하는 작업을 한다. 이런 공통 작업을 위한 헬퍼가 src/llmcompressor/pytorch/와 src/llmcompressor/utils/pytorch/에 있다. 이 글은 그 중 주요 유틸리티를 본다.
핵심 구조/코드 분석
디렉토리 구성
pytorch/
├── model_load/
│ └── helpers.py
└── utils/
├── helpers.py
├── sparsification.py
└── sparsification_info/
├── configs.py
├── helpers.py
└── module_sparsification_info.py
utils/
└── pytorch/
├── module.py
└── utils.py
이 디렉토리 분리는 역사적 유산이다. 최신 코드는 utils/pytorch/를 주로 쓰고, pytorch/utils/는 sparsification 통계 등 특수 용도에 사용된다.
utils/pytorch/module.py: 모듈 순회 헬퍼
def get_matching_layer(
model: torch.nn.Module,
pattern: str,
) -> Optional[torch.nn.Module]:
"""Get the first module matching a regex-like pattern"""
for name, module in model.named_modules():
if re.match(pattern, name):
return module
return None
def get_layer_by_name(model: torch.nn.Module, name: str) -> torch.nn.Module:
"""Traverse dotted name path to retrieve nested submodule"""
attrs = name.split(".")
current = model
for attr in attrs:
current = getattr(current, attr)
return current
def infer_sequential_targets(
model: torch.nn.Module,
user_targets: Optional[list[str]] = None,
) -> list[str]:
"""
Determine which module types to use as boundaries for sequential pipeline.
Falls back to HF model's `no_split_modules` attribute.
"""
if user_targets:
return user_targets
# HF 모델의 _no_split_modules 속성 조회
if hasattr(model, "_no_split_modules"):
return model._no_split_modules
# 대체 경로: decoder layer 를 자동 탐지
return _auto_detect_decoder_layers(model)
이 유틸리티들은 llm-compressor 전반에서 쓰이는 "모듈 navigation" 기본기다.
get_matching_layer: 정규식으로 모듈 검색get_layer_by_name: 점 표기법("model.layers.0.self_attn.q_proj") 경로로 nested 접근infer_sequential_targets: Sequential Pipeline이 모델을 쪼갤 때 경계 타입 자동 결정
_no_split_modules는 HuggingFace Transformers가 모델에 붙여두는 속성이다. 예를 들어 LLaMA는 ["LlamaDecoderLayer"]로 설정되어 있다. llm-compressor는 이를 활용해 "사용자가 sequential_targets를 명시하지 않으면 HF가 제공한 힌트를 그대로 쓴다".
pytorch/utils/sparsification.py: 희소화 통계
def get_prunable_layers(model: torch.nn.Module) -> list[tuple[str, torch.nn.Module]]:
"""Return all layers that can be pruned (Linear, Conv2d, etc.)"""
prunable = []
for name, module in model.named_modules():
if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
prunable.append((name, module))
return prunable
def module_sparsity(module: torch.nn.Module) -> float:
"""Compute fraction of zeros in a module's weight"""
if not hasattr(module, "weight"):
return 0.0
W = module.weight.data
return (W == 0).float().mean().item()
def model_sparsity_report(model: torch.nn.Module) -> dict:
"""Generate per-layer sparsity report"""
report = {}
total_params = 0
total_zeros = 0
for name, module in get_prunable_layers(model):
if not hasattr(module, "weight"):
continue
W = module.weight.data
zeros = (W == 0).sum().item()
total = W.numel()
report[name] = {
"sparsity": zeros / total,
"num_params": total,
"num_zeros": zeros,
}
total_params += total
total_zeros += zeros
report["_total"] = {
"sparsity": total_zeros / total_params if total_params > 0 else 0.0,
"num_params": total_params,
"num_zeros": total_zeros,
}
return report
이 함수들은 pruning 이후 "얼마나 많이 0이 되었는가"를 검증하는 데 쓰인다. 사용자나 테스트 코드가 "10% sparsity를 기대했는데 실제로 10.2%가 나왔다" 같은 sanity check를 할 수 있다.
sparsification_info/: 모듈별 상세 통계
@dataclass
class SparsificationConfig:
sparsity: float
structure: str # unstructured / 2:4 / block
mask_distribution: dict # 마스크 패턴 분석 결과
class ModuleSparsificationInfo:
"""Per-module sparsification statistics"""
def __init__(self, module):
self.module = module
self.weight = module.weight.data if hasattr(module, "weight") else None
self.num_params = self.weight.numel() if self.weight is not None else 0
self.num_zeros = (self.weight == 0).sum().item() if self.weight is not None else 0
@property
def sparsity(self) -> float:
return self.num_zeros / self.num_params if self.num_params > 0 else 0.0
@property
def density(self) -> float:
return 1.0 - self.sparsity
def is_2_4_compatible(self) -> bool:
"""Check if weight mask follows 2:4 pattern"""
if self.weight is None:
return False
# 4 열 묶음마다 정확히 2 개가 0 인지 검사
W = self.weight
rows, cols = W.shape
if cols % 4 != 0:
return False
W_reshaped = W.view(rows, cols // 4, 4)
zeros_per_group = (W_reshaped == 0).sum(dim=-1)
return (zeros_per_group == 2).all().item()
이 클래스는 "이 모듈이 진짜로 2:4 패턴인가"를 런타임에 검사할 수 있게 해준다. 2:4를 목표로 pruning했지만 실제로는 패턴이 깨졌다면 hardware acceleration을 못 받는다. 검증이 중요하다.
pytorch/model_load/helpers.py: 모델 로딩 헬퍼
def fallback_to_cpu_if_needed(model, target_device):
"""If target device has insufficient memory, load on CPU instead"""
try:
model.to(target_device)
except RuntimeError as e:
if "out of memory" in str(e):
logger.warning(f"OOM on {target_device}, falling back to CPU")
torch.cuda.empty_cache()
model.to("cpu")
else:
raise
def enable_quantized_forward(model):
"""Patch forward passes to emulate quantization during calibration"""
# compressed-tensors API 를 호출해 fake_quantize forward 활성화
...
모델 로딩 과정의 공통 실패 모드(OOM 등)를 우아하게 처리한다.
utils/dev.py와 utils/dist.py
# utils/dev.py
def get_main_device() -> torch.device:
"""Get the primary device for computation"""
if torch.cuda.is_available():
return torch.device("cuda:0")
return torch.device("cpu")
# utils/dist.py
def is_distributed() -> bool:
return dist.is_available() and dist.is_initialized()
def get_world_size() -> int:
return dist.get_world_size() if is_distributed() else 1
def get_rank() -> int:
return dist.get_rank() if is_distributed() else 0
분산 환경의 기본 유틸리티다. DDP 모드 지원을 위해 Quantization Base가 이 함수들을 호출해 rank/world_size를 확인한다.
왜 이 설계인가
1. 공통 패턴의 재사용. "모델 순회", "모듈 이름 매칭", "sparsity 측정" 같은 작업이 여러 Modifier에서 반복된다. 공용 헬퍼로 추출해 중복 제거.
2. infer_sequential_targets의 HF 속성 활용. _no_split_modules는 HF가 이미 정의한 정보다. 재발명하지 말고 이를 활용하는 것이 유지보수 측면에서 유리하다. 새 HF 모델이 나와도 _no_split_modules만 제대로 설정되어 있으면 llm-compressor가 자동 지원한다.
3. Sparsification 검증. pruning 결과를 자동 검증하는 유틸리티가 있어 "기대 sparsity"와 "실제 sparsity"의 차이를 빠르게 발견할 수 있다. CI/CD 통합에 유용.
4. 2:4 호환성 검사. 하드웨어 가속은 정확한 마스크 패턴에 의존한다. is_2_4_compatible가 런타임에 이를 검증해, 실수로 깨진 패턴이 체크포인트에 저장되는 것을 방지한다.
5. 분산 유틸리티 fallback. get_world_size가 분산 환경이 아니면 1을 반환하므로, Modifier 코드가 DDP와 단일 GPU를 같은 코드 경로로 처리할 수 있다.
마무리
PyTorch Utils는 프레임워크의 이음매(seam)다. 사용자 눈에는 안 보이지만 대부분의 Modifier가 의존한다. 마지막으로 Sentinel & Typing을 본다.
참고 자료
관련 포스트
llm-compressor 의 다른글
- 이전글 [llm-compressor] Dataset Calibration: c4/wikitext/ultrachat 로더
- 현재글 : [llm-compressor] PyTorch Utils: 희소화 통계와 모듈 헬퍼
- 다음글 [llm-compressor] Sentinel & Typing: 센티넬 객체와 타입 별칭
댓글