본문으로 건너뛰기

[PyTorch] MPS mul 성능 회귀 수정

PR 링크: pytorch/pytorch#172106 상태: Merged | 변경: +635 / -57

들어가며

PyTorch의 MPS(Metal Performance Shaders) 백엔드에서 mul 연산의 성능 회귀가 발생했다. 원인은 broadcasting 상황에서 strided iterator를 사용하면서 불필요한 오버헤드가 생긴 것이다. 이 PR은 broadcast와 scalar 케이스를 감지하여 전용 Metal 셰이더 커널로 분기하는 최적화를 추가한다.

핵심 코드 분석

Broadcast 감지 로직

static bool is_dense_broadcastable(const Tensor& from, const Tensor& into) {
  if (!from.is_contiguous() || !into.is_contiguous()) {
    return false;
  }
  bool checking_squeezable_dims = false;
  for (const auto dim : c10::irange(from.ndimension())) {
    if (checking_squeezable_dims) {
      if (from.size(-dim - 1) == 1) continue;
      return false;
    }
    checking_squeezable_dims = from.size(-dim - 1) != into.size(-dim - 1);
  }
  return true;
}

두 텐서가 contiguous이고, 앞쪽 차원만 1인 경우(예: [1, 1, 64][8, 8, 64]에 broadcast)를 감지한다. 이 조건이면 tid % broadcast_numel로 간단히 인덱싱할 수 있다.

Kernel 분기

Before:

const auto kernel_name = cast_needed
    ? fmt::format("{}_{}_cast_{}{}", name, suffix, scalarToMetalTypeString(out), alpha_suffix)
    : fmt::format("{}_{}_{}_{}{}", name, suffix, scalarToMetalTypeString(out),
                  scalarToMetalTypeString(input), alpha_suffix);

After:

if (use_scalar_kernel) {
    kernel_name = fmt::format("{}_dense_scalar{}_{}_{}{}",
        name, lhs_suffix, scalarToMetalTypeString(out),
        scalarToMetalTypeString(tensor_operand), alpha_suffix);
} else if (use_broadcast_kernel) {
    kernel_name = fmt::format("{}_dense_broadcast{}_{}_{}{}",
        name, broadcast_on_lhs ? "_rhs" : "",
        scalarToMetalTypeString(out),
        scalarToMetalTypeString(tensor_operand), alpha_suffix);
} else {
    // 기존 dense/strided 경로
}

세 가지 경로로 분기한다: (1) scalar 연산 -- 한쪽이 단일 원소, (2) broadcast 연산 -- 한쪽이 broadcast 가능, (3) 기존 dense/strided 경로. 각각 전용 Metal 커널을 사용한다.

Metal 셰이더 커널

template <typename T, typename F, typename om_t = opmath_t<T>>
kernel void binary_dense_broadcast(
    device result_of<F, T, T>* out [[buffer(0)]],
    constant T* input [[buffer(1)]],
    constant T* broadcast [[buffer(2)]],
    constant long& broadcast_numel [[buffer(3)]],
    uint tid [[thread_position_in_grid]]) {
  F f;
  using res_t = result_of<F, T, T>;
  out[tid] = static_cast<res_t>(
      f(om_t(input[tid]), om_t(broadcast[tid % broadcast_numel])));
}

tid % broadcast_numel로 broadcast 텐서를 인덱싱하는 것이 핵심이다. Strided iterator가 필요 없으므로 shape/stride 계산 오버헤드가 사라진다.

왜 이게 좋은가

MPS에서 binary 연산(mul, add 등)에 broadcast가 포함되면, 기존에는 strided 커널을 사용했다. Strided 커널은 매 원소마다 다차원 인덱스를 계산해야 하므로 느리다. 전용 broadcast 커널은 modulo 연산 하나로 인덱싱을 해결하여, 특히 tensor * scalar 같은 빈번한 패턴에서 큰 성능 개선을 가져온다.

정리

  • is_dense_broadcastable 함수로 broadcast 가능 여부를 빠르게 판별한다
  • Scalar, broadcast, dense, strided 네 경로로 분기하여 최적 커널을 선택한다
  • binary_dense_broadcast, binary_dense_scalar 등 12종의 새 Metal 커널 템플릿이 추가되었다
  • LHS/RHS 양쪽 모두에 대한 broadcast 및 cast 변형도 지원한다

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글