[pytorch] Inductor: bf16/fp16에서 addmm unfuse를 방지하여 정밀도 손실 해결
PR 링크: pytorch/pytorch#177144 상태: Merged | 변경: +22 / -0
들어가며
torch.addmm(bias, A, B)는 행렬 곱과 bias 덧셈을 하나의 커널에서 처리합니다. PyTorch Inductor의 pattern matcher는 성능 최적화를 위해 이를 mm + pointwise add로 분리(unfuse)하는 경우가 있습니다. 그러나 bf16/fp16처럼 정밀도가 낮은 데이터 타입에서는 이 분리가 문제를 일으킵니다. mm 결과가 half precision으로 한 번 더 truncate된 후 bias가 더해지면서, 딥 모델의 여러 레이어를 거치며 오차가 누적되기 때문입니다.
핵심 코드 분석
1. unfuse 방지 로직 (post_grad.py)
핵심 수정은 unfuse_bias_add_to_pointwise 함수 진입부에 dtype 체크를 추가한 것입니다.
Before:
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta):
def repl(inp, x1, x2, alpha, beta):
mm_result = x1 @ x2
# ... bias 덧셈 로직
After:
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta):
# Unfusing addmm introduces an extra bf16/fp16 truncation at the mm output
# that compounds through deep models and causes accuracy failures.
if inp.meta["val"].dtype in (torch.bfloat16, torch.float16):
return
def repl(inp, x1, x2, alpha, beta):
mm_result = x1 @ x2
# ... bias 덧셈 로직
inp.meta["val"].dtype을 통해 입력 텐서의 dtype을 확인하고, half precision인 경우 unfuse를 수행하지 않고 즉시 반환합니다. addmm이 fused 상태로 유지되면 내부적으로 fp32 accumulation을 사용할 수 있어 정밀도가 보존됩니다.
2. 테스트 추가
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_unfuse_bias_addmm_half_dtypes(self, dtype):
args = [
torch.randn(20, device=GPU_TYPE, dtype=dtype),
torch.randn(10, 15, device=GPU_TYPE, dtype=dtype),
torch.randn(15, 20, device=GPU_TYPE, dtype=dtype),
]
@torch.compile()
def fn(inp, a, b):
return torch.nn.functional.gelu(torch.ops.aten.addmm(inp, a, b))
_, (code) = run_and_get_code(fn, args[0], args[1], args[2])
# addmm이 unfuse되지 않았음을 확인
FileCheck().check("extern_kernels.addmm(").run(code[0])
테스트는 addmm + gelu 조합에서 생성된 코드에 extern_kernels.addmm(이 그대로 남아있는지 확인합니다. unfuse되었다면 addmm 대신 mm이 나타날 것입니다.
왜 이게 좋은가
이 버그는 개별 연산 수준에서는 감지하기 어렵지만, LLM처럼 수십~수백 개의 레이어를 통과하면 accuracy failure로 이어집니다. 수정 자체는 3줄의 조건문이지만, 컴파일러 최적화가 수치 안정성에 미치는 영향을 정확히 이해해야 도달할 수 있는 수정입니다. fused addmm은 내부적으로 fp32 accumulator를 사용할 수 있지만, unfused mm의 출력은 half precision으로 저장되면서 truncation이 발생합니다. 이 차이가 딥 모델에서 누적되는 것이 근본 원인입니다.
정리
- bf16/fp16
addmm을mm + add로 분리하면 중간 결과의 truncation으로 정밀도 손실 발생 - pattern matcher에 dtype 체크 3줄 추가로 half precision에서의 unfuse 방지
- fused
addmm의 fp32 accumulation 이점을 보존
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Grafana Loki] 배치 처리를 파이프라인 래퍼로 분리하여 캐시 통합 준비
- 현재글 : [pytorch] Inductor: bf16/fp16에서 addmm unfuse를 방지하여 정밀도 손실 해결
- 다음글 [Ray] Ray Data에 cuDF 배치 포맷 추가
댓글