본문으로 건너뛰기

[PyTorch] FlexAttention에 저정밀도 K/V 입력 지원 추가

PR 링크: pytorch/pytorch#170486 상태: Merged | 변경: +375 / -24

들어가며

FlexAttention은 PyTorch의 커스텀 attention API로, score_mod를 통해 다양한 attention 패턴을 표현할 수 있다. 그러나 기존에는 Q, K, V의 dtype이 동일해야 한다는 제약이 있었다. LLM 양자화 추론에서 K/V를 FP8로 저장하고 Q만 BF16/FP16으로 유지하는 것은 흔한 패턴인데, 이를 지원하지 못했다. 이 PR은 compiled 모드에서 저정밀도 K/V 입력을 허용한다.

핵심 코드 분석

dtype 검증 완화

Before:

def _validate_sdpa_input(query, key, value, ...):
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError(
            f"Expected query, key, and value to have the same dtype, "
            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
            f"and value.dtype: {value.dtype} instead."
        )

After:

def _validate_sdpa_input(query, key, value, ..., allow_lowp_kv=False):
    if not allow_lowp_kv:
        if query.dtype != key.dtype or query.dtype != value.dtype:
            raise ValueError(
                f"Expected query, key, and value to have the same dtype, "
                f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
                f"and value.dtype: {value.dtype} instead."
            )

allow_lowp_kv=True 플래그를 통해 dtype 불일치를 허용한다. FlexAttention은 이 플래그를 활성화하여 호출한다.

Math reference 구현 수정

Before:

return (
    post_mod_scores.to(query.dtype) @ value,
    logsumexp / math.log(2),
    max_scores / math.log(2),
)

After:

return (
    post_mod_scores.to(query.dtype) @ value.to(query.dtype),
    logsumexp / math.log(2),
    max_scores / math.log(2),
)

Math fallback 경로에서 value를 query dtype으로 캐스팅하여, FP8 value와의 행렬곱에서 정확도 문제를 방지한다.

Backward pass 제한

def flex_attention_backward(*args, **kwargs):
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError(
            f"Backward pass with mixed query, key, and value dtype "
            f"is not supported"
        )

Forward는 mixed dtype을 허용하지만, backward는 아직 지원하지 않는다. 이는 gradient 계산에서 dtype 변환의 수치적 안정성이 검증되지 않았기 때문이다.

왜 이게 좋은가

FP8 K/V 양자화 + score_mod 기반 역양자화 패턴이 가능해진다. Per-tensor 또는 per-head 스케일링으로 K/V를 FP8로 저장하면 메모리 사용량이 절반으로 줄어든다. 테스트에서 BF16 대비 SQNR 10dB 이상을 유지하면서도 KV cache 크기를 50% 줄일 수 있음을 검증했다.

def score_mod(score, b, h, m, n):
    return score * key_scale[b, h]  # Per-head dequantization

compiled_fn = torch.compile(flex_attention, fullgraph=True)
out = compiled_fn(query_bf16, key_fp8, value_fp8, score_mod) * value_scale

정리

  • FlexAttention에서 Q와 K/V의 dtype이 다른 입력을 허용한다
  • GPU compiled 모드에서 Triton 커널이 자동으로 dtype 변환을 처리한다
  • Backward pass는 mixed dtype을 지원하지 않으며, 명확한 에러 메시지를 제공한다
  • Per-tensor/per-head FP8 양자화 패턴에 대한 SQNR 테스트가 추가되었다

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글