본문으로 건너뛰기

[llm-compressor] SpinQuant: 학습된 회전 행렬 기반 양자화

들어가며

SpinQuant는 Meta가 2024년에 발표한 논문으로, QuIP의 랜덤 회전 아이디어를 한 단계 발전시켰다. QuIP는 랜덤 Hadamard를 쓰지만, SpinQuant는 "회전을 학습해서 최적 회전을 찾는다". Cayley SGD(Cayley 매니폴드 기반 최적화)로 직교 행렬을 학습하며, LLaMA-3 같은 최신 모델에서 QuIP#보다 더 좋은 결과를 보인다. SpinQuant 논문과 llm-compressor의 src/llmcompressor/modifiers/transform/spinquant/base.py를 분석한다.

공식 문서

논문 핵심 내용

SpinQuant는 네 종류의 회전을 제안한다.

  1. R1 (global online rotation): 모델 전체에 걸친 단일 회전. residual stream에 적용.
  2. R2 (attention-specific): 각 attention head의 내부 회전.
  3. R3 (MLP online rotation): MLP 블록의 내부 회전.
  4. R4 (per-layer online rotation): 각 Linear 레이어마다 독립적인 회전 (Hadamard).

R1과 R2는 학습 가능하다. Cayley SGD로 직교 제약을 유지하면서 학습한다. R3와 R4는 고정된 Hadamard를 쓴다 (QuIP처럼).

Cayley SGD는 직교 매니폴드에서 경사 하강을 수행하는 방법이다. 경사 $G$를 직교 접선 공간으로 사영한 뒤, Cayley 변환으로 업데이트된 직교 행렬을 얻는다.

$$ R_{t+1} = R_t \cdot \text{Cayley}(-\eta \hat{G}) $$

여기서 $\hat{G}$는 $G$의 skew-symmetric 부분이고, Cayley 변환은 $\text{Cayley}(A) = (I - A)(I + A)^{-1}$로 정의된다. 이 변환의 결과는 항상 직교 행렬이다.

벤치마크 (논문 기준)

방법 LLaMA-2 7B W4A4 PPL
QuIP# 8.91
SpinQuant 6.98
FP16 reference 5.47

SpinQuant가 W4A4(가중치 4비트 + 활성화 4비트)에서 QuIP#보다 유의미하게 좋다. 학습된 회전이 랜덤 회전보다 더 "이 모델에 맞는" 변환을 찾기 때문이다.

핵심 구조/코드 분석

SpinQuantModifier 파라미터

class SpinQuantModifier(Modifier):
    """Implements SpinQuant from https://arxiv.org/abs/2405.16406"""

    rotations: list[str] = field(default_factory=lambda: ["R1", "R2", "R4"])
    learnable: list[str] = field(default_factory=lambda: ["R1", "R2"])
    lr: float = 1e-4                      # Cayley SGD 학습률
    iters: int = 200                      # 학습 반복 횟수
    mappings: list[SpinQuantMapping] | None = None
    norm_mappings: list[NormMapping] | None = None
파라미터 기본값 의미
rotations ["R1", "R2", "R4"] 적용할 회전 타입. R3는 선택적
learnable ["R1", "R2"] 학습할 회전. 나머지는 고정
lr 1e-4 Cayley SGD 학습률
iters 200 경사 하강 반복 수
mappings None 아키텍처별 변환 지점 (자동 추론 권장)
norm_mappings None RMSNorm 흡수 위치

SpinQuantMappingNormMapping

# mappings.py
@dataclass
class SpinQuantMapping:
    name: str                      # "R1" | "R2" | "R3" | "R4"
    target_layers: list[str]       # 회전이 곱해질 Linear 목록
    previous_layers: list[str]     # 회전 역원이 흡수될 레이어 (RMSNorm 등)

LLAMA_SPIN_MAPPINGS = [
    SpinQuantMapping(
        name="R1",
        target_layers=[
            "model.embed_tokens",
            "lm_head",
            ... # residual 에 연결된 모든 레이어
        ],
        previous_layers=[],
    ),
    SpinQuantMapping(
        name="R2",
        target_layers=["self_attn.o_proj"],
        previous_layers=["self_attn.v_proj"],
    ),
    ...
]
# norm_mappings.py
@dataclass
class NormMapping:
    norm_layer: str                # "input_layernorm", "post_attention_layernorm"
    following_layers: list[str]    # 이 norm 뒤에 오는 레이어들

LLAMA_NORM_MAPPINGS = [
    NormMapping(
        norm_layer="input_layernorm",
        following_layers=["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
    ),
    NormMapping(
        norm_layer="post_attention_layernorm",
        following_layers=["mlp.gate_proj", "mlp.up_proj"],
    ),
]

이 두 매핑이 SpinQuant의 아키텍처 의존성을 캡슐화한다. 새 모델을 지원하려면 두 매핑을 추가하면 된다.

on_initialize: 회전 생성과 적용

def on_initialize(self, state: State, **kwargs) -> bool:
    model = state.model
    hidden_size = model.config.hidden_size

    # 1) R1 생성 (global)
    if "R1" in self.rotations:
        R1 = self._init_rotation(hidden_size)   # 랜덤 직교 행렬
        if "R1" in self.learnable:
            R1 = self._train_rotation(R1, model, scope="residual")
        self._apply_rotation(model, "R1", R1)

    # 2) R2 생성 (per attention head)
    if "R2" in self.rotations:
        head_dim = hidden_size // model.config.num_attention_heads
        R2 = self._init_rotation(head_dim)
        if "R2" in self.learnable:
            R2 = self._train_rotation(R2, model, scope="attention")
        self._apply_rotation(model, "R2", R2)

    # 3) R4 생성 (per-layer Hadamard)
    if "R4" in self.rotations:
        for name, module in self._get_r4_targets(model):
            R4 = self._make_hadamard(module.in_features)
            self._apply_rotation_to_module(module, R4)

    return True

R1과 R2는 학습되고, R4는 고정 Hadamard다. 학습은 _train_rotation에서 일어난다.

_train_rotation: Cayley SGD

def _train_rotation(self, R_init, model, scope):
    R = R_init.clone().requires_grad_(True)
    optimizer = CayleySGD([R], lr=self.lr)

    # 캘리브레이션 데이터로 양자화 손실 측정
    for step in range(self.iters):
        # (R 적용 전 forward) - (R 적용 후 fake_quantize forward) 오차 계산
        loss = self._measure_rotation_quant_loss(R, model, scope)
        loss.backward()

        # Cayley SGD update: R ← R · Cayley(-lr · skew(G))
        optimizer.step()
        optimizer.zero_grad()

    return R.detach()


class CayleySGD(torch.optim.Optimizer):
    def step(self):
        for group in self.param_groups:
            lr = group["lr"]
            for R in group["params"]:
                if R.grad is None:
                    continue
                G = R.grad

                # Skew-symmetric projection: A = G @ R^T - R @ G^T
                A = G @ R.T - R @ G.T

                # Cayley transform: R ← R · (I - lr/2 * A)^-1 · (I + lr/2 * A)
                n = R.shape[0]
                I = torch.eye(n, device=R.device, dtype=R.dtype)
                update = torch.linalg.solve(I + lr/2 * A, I - lr/2 * A)
                R.data = R.data @ update

Cayley SGD의 핵심은 직교 제약을 자동으로 유지한다는 점이다. 일반 SGD로 $R$을 업데이트하면 직교성이 깨지지만, Cayley 변환은 수학적으로 "어떤 $A$에 대해서도 $\text{Cayley}(A)$가 직교"를 보장한다. 따라서 매 스텝 후 R은 항상 정확히 직교 행렬이다.

_apply_rotation: RMSNorm 흡수

def _apply_rotation(self, model, rotation_name, R):
    """
    Apply rotation by absorbing it into weights and RMSNorm.
    Mathematical identity: y = W R R^T x = (WR)(R^T x)
    """
    mapping = self._get_mapping(rotation_name)

    # W ← W R  (target layers 의 입력 차원에 R 곱)
    for layer_name in mapping.target_layers:
        module = model.get_submodule(layer_name)
        module.weight.data = module.weight.data @ R

    # RMSNorm weight 를 R 로 "회전"
    # 수학적 트릭: RMSNorm(R^T x) 은 일반적으로 R^T RMSNorm(x) 이 아니지만,
    # RMSNorm 의 weight 가 scalar-per-channel 이라 원소 곱으로 흡수 가능
    for norm_map in self.norm_mappings:
        if any(layer in mapping.target_layers for layer in norm_map.following_layers):
            norm_module = model.get_submodule(norm_map.norm_layer)
            norm_module.weight.data = (R.T @ norm_module.weight.data.unsqueeze(-1)).squeeze()

RMSNorm과 직교 회전의 비교환성이 SpinQuant의 기술적 난점이다. 정확히는 $\text{RMSNorm}(R^T x) \neq R^T \text{RMSNorm}(x)$이지만, 특정 조건 하에서 RMSNorm의 per-channel weight를 회전에 흡수하면 결과가 같아진다. 이 흡수가 norm_mappings의 역할이다.

왜 이 설계인가

1. 학습된 회전 > 랜덤 회전. QuIP의 랜덤 Hadamard는 "일반적으로 좋은" 변환이다. SpinQuant는 "이 모델에 특화된" 변환을 학습한다. W4A4 같은 극단 시나리오에서 차이가 크다.

2. Cayley SGD의 직교 보장. 일반 SGD는 직교성을 잃지만, Cayley 변환은 수학적으로 직교를 유지한다. 실험 중 $R$을 다시 직교화할 필요가 없어 안정적이다.

3. 4가지 회전의 계층화. R1(global) → R2(attention) → R3(MLP) → R4(per-layer). 각 회전이 다른 스케일의 incoherence를 처리한다. 필요에 따라 일부만 활성화할 수 있다.

4. norm_mappings 분리. RMSNorm 흡수는 미묘한 수학적 문제다. 아키텍처별로 정확한 흡수 지점을 norm_mappings.py에 하드코딩해 오류를 방지한다.

5. 학습 비용 관리. iters=200은 QuantizationModifier 자체의 캘리브레이션보다는 길지만, GPTQ의 헤시안 계산과 비슷한 수준이다. 결과 품질 차이를 고려하면 수용 가능하다.

마무리

SpinQuant는 llm-compressor의 가장 정교한 Transform Modifier다. W4A4 같은 극단 시나리오에 최적이다. 다음 글은 마지막 Transform 구현인 iMatrix Transform을 본다.

참고 자료

댓글

관련 포스트

llm-compressor 의 다른글