본문으로 건너뛰기

[PyTorch] Inductor MPS Metal 셰이더 half-precision 타입 불일치 수정

PR 링크: pytorch/pytorch#177193 상태: Merged | 변경: +7 / -3

들어가며

PyTorch Inductor의 MPS 백엔드는 Metal 셰이더 코드를 자동 생성한다. half (float16) 타입 텐서에 masked 또는 where 연산을 적용할 때, 생성된 Metal 코드에서 타입 불일치가 발생하여 컴파일 에러가 나는 버그가 있었다. 원인은 조건문의 결과값과 대체값의 타입이 서로 달라 Metal 컴파일러가 암시적 변환을 거부하는 것이었다.

핵심 코드 분석

masked 연산 수정

Before:

@staticmethod
def masked(mask, body, other):
    # ...
    V.kernel.compute.writeline(f"{var} = {rc};")
    V.kernel.compute.writeline(f"}} else {var} = {other_str};")
    return var

After:

@staticmethod
def masked(mask, body, other):
    # ...
    V.kernel.compute.writeline(
        f"{var} = static_cast<decltype({var})>({rc});"
    )
    V.kernel.compute.writeline(
        f"}} else {var} = static_cast<decltype({var})>({other_str});"
    )
    return var

static_cast<decltype({var})>를 사용하여 결과값과 대체값 모두를 변수의 선언 타입으로 명시적 캐스팅한다. decltype은 컴파일 타임에 변수 타입을 추론하므로, 어떤 타입 조합에서도 안전하게 동작한다.

where 연산 수정

Before:

@staticmethod
def where(a, b, c):
    return f"{a} ? {b} : {value_to_metal(c)}"

After:

@staticmethod
def where(a, b, c):
    return f"{a} ? {b} : static_cast<decltype({b})>({value_to_metal(c)})"

삼항 연산자에서 c 값(보통 0.0이나 상수)이 float로 추론되는 반면 bhalf인 경우, Metal은 halffloat 간 암시적 변환을 허용하지 않는다. static_castb의 타입에 맞춘다.

왜 이게 좋은가

이 버그는 half precision 텐서에 마스킹 연산을 할 때 발생한다. 예를 들어 attention mask 적용이나 dropout에서 where(mask, x, 0.0)을 사용하면, 0.0float로 추론되어 half 타입의 x와 충돌한다. 수정이 없으면 Metal 셰이더 컴파일이 실패하여 해당 모델을 MPS에서 실행할 수 없다.

단 4줄의 변경(+7/-3)이지만, MPS에서 half precision 모델의 실행 가능성을 복원하는 중요한 수정이다.

정리

  • maskedwhere 연산의 Metal codegen에 static_cast<decltype(...)>를 추가했다
  • half 타입에서 발생하던 Metal 셰이더 컴파일 에러가 해결된다
  • decltype 기반 캐스팅으로 타입 안전성을 유지하면서 모든 정밀도를 지원한다
  • 최소한의 변경으로 최대 영향을 미치는 전형적인 codegen 버그 수정이다

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글