본문으로 건너뛰기

[vLLM] KV Cache Quantization: KV 캐시 FP8/INT8 양자화

들어가며

KV 캐시는 긴 시퀀스에서 GPU 메모리의 주요 병목이다. FP16/BF16 대신 FP8이나 INT8로 KV 캐시를 양자화하면 메모리 사용량을 절반으로 줄일 수 있다. vLLM은 vllm/model_executor/layers/quantization/kv_cache.py에서 이 양자화 로직을 구현한다.

공식 문서

vLLM 공식 문서: Quantized KV Cache

핵심 구조/코드 분석

BaseKVCacheMethod

class BaseKVCacheMethod(QuantizeMethodBase):
    def __init__(self, quant_config: QuantizationConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module):
        layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
        layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)

초기값 -1.0은 "아직 로드되지 않은" 상태를 나타낸다. 체크포인트에 스케일이 있으면 로딩 시 덮어쓰이고, 없으면 이후 처리에서 기본값(1.0)이 설정된다. Q, K, V, prob 네 가지 스케일을 모두 관리한다.

Per-Token-Head vs Per-Tensor 양자화

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    # Per-token-head: 커널이 동적으로 스케일 계산
    if kv_cache_uses_per_token_head_scales(layer.kv_cache_dtype):
        layer._k_scale.copy_(1.0)
        layer._v_scale.copy_(1.0)
        del layer.k_scale, layer.v_scale, layer.q_scale, layer.prob_scale
        return

    # Per-tensor: 체크포인트의 스케일 사용
    if is_quantized_kv_cache(layer.kv_cache_dtype) and not layer.calculate_kv_scales:
        if layer.k_scale > 0.0 and layer.v_scale > 0.0:
            k_scale = layer.k_scale.to("cpu").tolist()
            v_scale = layer.v_scale.to("cpu").tolist()
        elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
            k_scale = 1.0  # 기본값
            v_scale = 1.0

두 가지 양자화 모드가 있다:

  • Per-token-head: 각 토큰, 각 어텐션 헤드마다 별도의 스케일을 동적으로 계산한다. 체크포인트 스케일이 필요 없고, 커널이 캐시 저장 시점에 자동으로 계산한다.
  • Per-tensor: 전체 레이어에 하나의 K/V 스케일을 사용한다. 체크포인트에서 로드하거나 기본값 1.0을 사용한다.

AMD FP8 FNUZ 처리

if current_platform.is_fp8_fnuz():
    k_scale *= 2
    v_scale *= 2

AMD GPU에서 사용하는 FP8 FNUZ(Finite, Not Undefined, No Zero) 형식은 표현 범위가 다르기 때문에, 스케일을 2배로 조정하는 플랫폼 특화 처리가 있다.

Q/Prob 스케일과 FP8 Attention

if layer.q_scale > 0.0:
    q_scale = layer.q_scale
    layer.calculate_kv_scales = False
else:
    q_scale = 1.0

if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 or prob_scale == 1.0):
    logger.warning_once(
        f"Using uncalibrated q_scale {q_scale} and/or prob_scale "
        f"{prob_scale} with fp8 attention..."
    )

FP8 Attention(FlashAttention, FlashInfer)을 사용할 때는 Q 텐서와 softmax 확률(prob)에 대한 스케일도 필요하다. 캘리브레이션되지 않은 스케일(1.0)을 사용하면 정확도 문제가 발생할 수 있다고 경고한다.

스케일 통합 로직

# 단일 kv_scale만 있는 경우 -> k_scale에 매핑 후 v_scale에 복제
assert layer.k_scale > 0.0
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist()

일부 체크포인트는 단일 kv_scale만 제공한다. 이 경우 가중치 로딩 시 k_scale에 매핑하고, 이후 처리에서 k_scalev_scale에 동일한 값을 복제한다.

왜 이 설계인가

  1. -1.0 센티널 값: 파라미터의 유무를 별도 플래그 없이 값 자체로 판단한다. 양의 스케일은 유효한 체크포인트 값, 음의 스케일은 미로드 상태를 뜻한다. 이는 다양한 체크포인트 형식(k_scale+v_scale, 단일 kv_scale, 스케일 없음)을 통합적으로 처리하기 위한 전략이다.

  2. Per-token-head의 동적 계산: 정적 per-tensor 스케일은 체크포인트에 종속적이어서, 캘리브레이션 없이 사용하면 정확도가 떨어진다. Per-token-head 방식은 런타임에 최적 스케일을 계산하므로 범용적으로 사용할 수 있다.

  3. process_weights_after_loading 패턴: vLLM은 가중치 로딩과 후처리를 분리한다. 로딩 시에는 체크포인트 값을 그대로 읽고, 후처리에서 플랫폼별 변환(FNUZ), 스케일 검증, 기본값 할당을 수행한다. 이 패턴은 다양한 양자화 방식을 통합하는 데 효과적이다.

참고 자료

댓글

관련 포스트

vLLM 의 다른글