본문으로 건너뛰기

[pytorch] MPS: 2-pass SDPA의 메모리 손상을 float accumulator 강제로 수정

PR 링크: pytorch/pytorch#175580 상태: Merged | 변경: +22 / -4

들어가며

Apple Silicon의 MPS(Metal Performance Shaders) 백엔드에서 Scaled Dot-Product Attention(SDPA)을 실행할 때, 시퀀스 길이가 긴 경우 2-pass 알고리즘이 사용됩니다. 이 알고리즘은 블록 단위로 attention을 계산한 후 결과를 합산하는데, 합산에 사용되는 중간 버퍼(sums, maxs)가 입력과 동일한 half precision(bf16/fp16)으로 생성되면서 오버플로우와 메모리 손상이 발생하는 버그가 보고되었습니다.

핵심 코드 분석

1. 중간 버퍼의 dtype 강제 (Attention.mm)

핵심 수정은 단 2줄입니다.

Before:

auto sums = at::empty({batchSize, num_heads, seq_len_q, blocks}, q_.options());
auto maxs = at::empty({batchSize, num_heads, seq_len_q, blocks}, q_.options());

After:

auto sums = at::empty({batchSize, num_heads, seq_len_q, blocks}, q_.options().dtype(kFloat));
auto maxs = at::empty({batchSize, num_heads, seq_len_q, blocks}, q_.options().dtype(kFloat));

q_.options()는 입력 query 텐서의 device, dtype 등을 복사하는데, 입력이 bf16/fp16이면 중간 버퍼도 동일한 half precision으로 생성됩니다. .dtype(kFloat)를 추가하여 중간 버퍼를 항상 float32로 생성하도록 변경했습니다.

2. 회귀 테스트 추가

@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float])
def test_sdpa_2pass(self, dtype):
    # Regression test for https://github.com/pytorch/pytorch/issues/174861
    q = torch.randn(1, 32, 1, 128, dtype=dtype)
    k = torch.randn(1, 2, 1024, 128, dtype=dtype)
    v = torch.randn(1, 2, 1024, 128, dtype=dtype)
    sdpa_kwargs = {"enable_gqa": True}

    out_cpu = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)
    out_mps = F.scaled_dot_product_attention(
        q.to("mps"), k.to("mps"), v.to("mps"), **sdpa_kwargs
    )
    tol = 0.1 if dtype == torch.bfloat16 else 0.01
    self._compare_tensors(out_mps.cpu(), out_cpu, tol=tol)

테스트는 num_heads=32, kv_heads=2의 GQA(Grouped Query Attention) 구성에서 seq_len_kv=1024로 2-pass 경로를 트리거합니다. CPU와 MPS 결과를 비교하여 정확성을 검증합니다.

왜 이게 좋은가

2-pass attention 알고리즘에서 sums는 softmax의 분모(지수합), maxs는 수치 안정성을 위한 최대값을 저장합니다. 이 값들이 half precision으로 저장되면 블록 간 합산 시 오버플로우가 발생할 수 있고, 이는 단순한 오차가 아니라 메모리 손상으로 이어집니다. float32 accumulator는 attention 구현의 표준적인 모범 사례이며, CUDA 백엔드의 FlashAttention도 동일한 전략을 사용합니다. 수정 코드 자체는 2줄이지만, 문제의 근본 원인을 정확히 파악한 결과입니다.

정리

  • MPS 2-pass SDPA에서 half precision 중간 버퍼로 인한 메모리 손상 버그 수정
  • sumsmaxs 버퍼를 .dtype(kFloat)로 float32 강제 전환
  • GQA 구성의 회귀 테스트 추가 (bf16, fp16, float32)

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글