[PyTorch] FlexAttention에 저정밀도 K/V 입력 지원 추가
PR 링크: pytorch/pytorch#170486 상태: Merged | 변경: +375 / -24
들어가며
FlexAttention은 PyTorch의 커스텀 attention API로, score_mod를 통해 다양한 attention 패턴을 표현할 수 있다. 그러나 기존에는 Q, K, V의 dtype이 동일해야 한다는 제약이 있었다. LLM 양자화 추론에서 K/V를 FP8로 저장하고 Q만 BF16/FP16으로 유지하는 것은 흔한 패턴인데, 이를 지원하지 못했다. 이 PR은 compiled 모드에서 저정밀도 K/V 입력을 허용한다.
핵심 코드 분석
dtype 검증 완화
Before:
def _validate_sdpa_input(query, key, value, ...):
if query.dtype != key.dtype or query.dtype != value.dtype:
raise ValueError(
f"Expected query, key, and value to have the same dtype, "
f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
f"and value.dtype: {value.dtype} instead."
)
After:
def _validate_sdpa_input(query, key, value, ..., allow_lowp_kv=False):
if not allow_lowp_kv:
if query.dtype != key.dtype or query.dtype != value.dtype:
raise ValueError(
f"Expected query, key, and value to have the same dtype, "
f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
f"and value.dtype: {value.dtype} instead."
)
allow_lowp_kv=True 플래그를 통해 dtype 불일치를 허용한다. FlexAttention은 이 플래그를 활성화하여 호출한다.
Math reference 구현 수정
Before:
return (
post_mod_scores.to(query.dtype) @ value,
logsumexp / math.log(2),
max_scores / math.log(2),
)
After:
return (
post_mod_scores.to(query.dtype) @ value.to(query.dtype),
logsumexp / math.log(2),
max_scores / math.log(2),
)
Math fallback 경로에서 value를 query dtype으로 캐스팅하여, FP8 value와의 행렬곱에서 정확도 문제를 방지한다.
Backward pass 제한
def flex_attention_backward(*args, **kwargs):
if query.dtype != key.dtype or query.dtype != value.dtype:
raise ValueError(
f"Backward pass with mixed query, key, and value dtype "
f"is not supported"
)
Forward는 mixed dtype을 허용하지만, backward는 아직 지원하지 않는다. 이는 gradient 계산에서 dtype 변환의 수치적 안정성이 검증되지 않았기 때문이다.
왜 이게 좋은가
FP8 K/V 양자화 + score_mod 기반 역양자화 패턴이 가능해진다. Per-tensor 또는 per-head 스케일링으로 K/V를 FP8로 저장하면 메모리 사용량이 절반으로 줄어든다. 테스트에서 BF16 대비 SQNR 10dB 이상을 유지하면서도 KV cache 크기를 50% 줄일 수 있음을 검증했다.
def score_mod(score, b, h, m, n):
return score * key_scale[b, h] # Per-head dequantization
compiled_fn = torch.compile(flex_attention, fullgraph=True)
out = compiled_fn(query_bf16, key_fp8, value_fp8, score_mod) * value_scale
정리
- FlexAttention에서 Q와 K/V의 dtype이 다른 입력을 허용한다
- GPU compiled 모드에서 Triton 커널이 자동으로 dtype 변환을 처리한다
- Backward pass는 mixed dtype을 지원하지 않으며, 명확한 에러 메시지를 제공한다
- Per-tensor/per-head FP8 양자화 패턴에 대한 SQNR 테스트가 추가되었다
참고 자료
- FlexAttention 블로그 -- FlexAttention API 소개
- FP8 Training/Inference -- FP8 수치 포맷 활용
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] Proton의 Runtime과 Metric 상관관계 단순화로 오버헤드 감소
- 현재글 : [PyTorch] FlexAttention에 저정밀도 K/V 입력 지원 추가
- 다음글 [pydantic-ai] DBOS 테스트용 인메모리 SQLite 되돌리기: 파일 기반 DB 복원
댓글