[PyTorch] Inductor MPS Metal 셰이더 half-precision 타입 불일치 수정
PR 링크: pytorch/pytorch#177193 상태: Merged | 변경: +7 / -3
들어가며
PyTorch Inductor의 MPS 백엔드는 Metal 셰이더 코드를 자동 생성한다. half (float16) 타입 텐서에 masked 또는 where 연산을 적용할 때, 생성된 Metal 코드에서 타입 불일치가 발생하여 컴파일 에러가 나는 버그가 있었다. 원인은 조건문의 결과값과 대체값의 타입이 서로 달라 Metal 컴파일러가 암시적 변환을 거부하는 것이었다.
핵심 코드 분석
masked 연산 수정
Before:
@staticmethod
def masked(mask, body, other):
# ...
V.kernel.compute.writeline(f"{var} = {rc};")
V.kernel.compute.writeline(f"}} else {var} = {other_str};")
return var
After:
@staticmethod
def masked(mask, body, other):
# ...
V.kernel.compute.writeline(
f"{var} = static_cast<decltype({var})>({rc});"
)
V.kernel.compute.writeline(
f"}} else {var} = static_cast<decltype({var})>({other_str});"
)
return var
static_cast<decltype({var})>를 사용하여 결과값과 대체값 모두를 변수의 선언 타입으로 명시적 캐스팅한다. decltype은 컴파일 타임에 변수 타입을 추론하므로, 어떤 타입 조합에서도 안전하게 동작한다.
where 연산 수정
Before:
@staticmethod
def where(a, b, c):
return f"{a} ? {b} : {value_to_metal(c)}"
After:
@staticmethod
def where(a, b, c):
return f"{a} ? {b} : static_cast<decltype({b})>({value_to_metal(c)})"
삼항 연산자에서 c 값(보통 0.0이나 상수)이 float로 추론되는 반면 b가 half인 경우, Metal은 half와 float 간 암시적 변환을 허용하지 않는다. static_cast로 b의 타입에 맞춘다.
왜 이게 좋은가
이 버그는 half precision 텐서에 마스킹 연산을 할 때 발생한다. 예를 들어 attention mask 적용이나 dropout에서 where(mask, x, 0.0)을 사용하면, 0.0이 float로 추론되어 half 타입의 x와 충돌한다. 수정이 없으면 Metal 셰이더 컴파일이 실패하여 해당 모델을 MPS에서 실행할 수 없다.
단 4줄의 변경(+7/-3)이지만, MPS에서 half precision 모델의 실행 가능성을 복원하는 중요한 수정이다.
정리
masked와where연산의 Metal codegen에static_cast<decltype(...)>를 추가했다half타입에서 발생하던 Metal 셰이더 컴파일 에러가 해결된다decltype기반 캐스팅으로 타입 안전성을 유지하면서 모든 정밀도를 지원한다- 최소한의 변경으로 최대 영향을 미치는 전형적인 codegen 버그 수정이다
참고 자료
- 이슈 #176436 -- 원본 버그 리포트
- Metal Shading Language -- Metal의 타입 변환 규칙
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Loki] Helm 차트 Memcached CPU 리소스 오버라이드 지원 추가
- 현재글 : [PyTorch] Inductor MPS Metal 셰이더 half-precision 타입 불일치 수정
- 다음글 [Grafana Loki] 배치 처리를 파이프라인 래퍼로 분리하여 캐시 통합 준비
댓글