본문으로 건너뛰기

[pytorch] Inductor: bf16/fp16에서 addmm unfuse를 방지하여 정밀도 손실 해결

PR 링크: pytorch/pytorch#177144 상태: Merged | 변경: +22 / -0

들어가며

torch.addmm(bias, A, B)는 행렬 곱과 bias 덧셈을 하나의 커널에서 처리합니다. PyTorch Inductor의 pattern matcher는 성능 최적화를 위해 이를 mm + pointwise add로 분리(unfuse)하는 경우가 있습니다. 그러나 bf16/fp16처럼 정밀도가 낮은 데이터 타입에서는 이 분리가 문제를 일으킵니다. mm 결과가 half precision으로 한 번 더 truncate된 후 bias가 더해지면서, 딥 모델의 여러 레이어를 거치며 오차가 누적되기 때문입니다.

핵심 코드 분석

1. unfuse 방지 로직 (post_grad.py)

핵심 수정은 unfuse_bias_add_to_pointwise 함수 진입부에 dtype 체크를 추가한 것입니다.

Before:

def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta):
    def repl(inp, x1, x2, alpha, beta):
        mm_result = x1 @ x2
        # ... bias 덧셈 로직

After:

def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta):
    # Unfusing addmm introduces an extra bf16/fp16 truncation at the mm output
    # that compounds through deep models and causes accuracy failures.
    if inp.meta["val"].dtype in (torch.bfloat16, torch.float16):
        return

    def repl(inp, x1, x2, alpha, beta):
        mm_result = x1 @ x2
        # ... bias 덧셈 로직

inp.meta["val"].dtype을 통해 입력 텐서의 dtype을 확인하고, half precision인 경우 unfuse를 수행하지 않고 즉시 반환합니다. addmm이 fused 상태로 유지되면 내부적으로 fp32 accumulation을 사용할 수 있어 정밀도가 보존됩니다.

2. 테스트 추가

@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_unfuse_bias_addmm_half_dtypes(self, dtype):
    args = [
        torch.randn(20, device=GPU_TYPE, dtype=dtype),
        torch.randn(10, 15, device=GPU_TYPE, dtype=dtype),
        torch.randn(15, 20, device=GPU_TYPE, dtype=dtype),
    ]

    @torch.compile()
    def fn(inp, a, b):
        return torch.nn.functional.gelu(torch.ops.aten.addmm(inp, a, b))

    _, (code) = run_and_get_code(fn, args[0], args[1], args[2])
    # addmm이 unfuse되지 않았음을 확인
    FileCheck().check("extern_kernels.addmm(").run(code[0])

테스트는 addmm + gelu 조합에서 생성된 코드에 extern_kernels.addmm(이 그대로 남아있는지 확인합니다. unfuse되었다면 addmm 대신 mm이 나타날 것입니다.

왜 이게 좋은가

이 버그는 개별 연산 수준에서는 감지하기 어렵지만, LLM처럼 수십~수백 개의 레이어를 통과하면 accuracy failure로 이어집니다. 수정 자체는 3줄의 조건문이지만, 컴파일러 최적화가 수치 안정성에 미치는 영향을 정확히 이해해야 도달할 수 있는 수정입니다. fused addmm은 내부적으로 fp32 accumulator를 사용할 수 있지만, unfused mm의 출력은 half precision으로 저장되면서 truncation이 발생합니다. 이 차이가 딥 모델에서 누적되는 것이 근본 원인입니다.

정리

  • bf16/fp16 addmmmm + add로 분리하면 중간 결과의 truncation으로 정밀도 손실 발생
  • pattern matcher에 dtype 체크 3줄 추가로 half precision에서의 unfuse 방지
  • fused addmm의 fp32 accumulation 이점을 보존

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글