[PyTorch] MPS mul 성능 회귀 수정
PR 링크: pytorch/pytorch#172106 상태: Merged | 변경: +635 / -57
들어가며
PyTorch의 MPS(Metal Performance Shaders) 백엔드에서 mul 연산의 성능 회귀가 발생했다. 원인은 broadcasting 상황에서 strided iterator를 사용하면서 불필요한 오버헤드가 생긴 것이다. 이 PR은 broadcast와 scalar 케이스를 감지하여 전용 Metal 셰이더 커널로 분기하는 최적화를 추가한다.
핵심 코드 분석
Broadcast 감지 로직
static bool is_dense_broadcastable(const Tensor& from, const Tensor& into) {
if (!from.is_contiguous() || !into.is_contiguous()) {
return false;
}
bool checking_squeezable_dims = false;
for (const auto dim : c10::irange(from.ndimension())) {
if (checking_squeezable_dims) {
if (from.size(-dim - 1) == 1) continue;
return false;
}
checking_squeezable_dims = from.size(-dim - 1) != into.size(-dim - 1);
}
return true;
}
두 텐서가 contiguous이고, 앞쪽 차원만 1인 경우(예: [1, 1, 64]를 [8, 8, 64]에 broadcast)를 감지한다. 이 조건이면 tid % broadcast_numel로 간단히 인덱싱할 수 있다.
Kernel 분기
Before:
const auto kernel_name = cast_needed
? fmt::format("{}_{}_cast_{}{}", name, suffix, scalarToMetalTypeString(out), alpha_suffix)
: fmt::format("{}_{}_{}_{}{}", name, suffix, scalarToMetalTypeString(out),
scalarToMetalTypeString(input), alpha_suffix);
After:
if (use_scalar_kernel) {
kernel_name = fmt::format("{}_dense_scalar{}_{}_{}{}",
name, lhs_suffix, scalarToMetalTypeString(out),
scalarToMetalTypeString(tensor_operand), alpha_suffix);
} else if (use_broadcast_kernel) {
kernel_name = fmt::format("{}_dense_broadcast{}_{}_{}{}",
name, broadcast_on_lhs ? "_rhs" : "",
scalarToMetalTypeString(out),
scalarToMetalTypeString(tensor_operand), alpha_suffix);
} else {
// 기존 dense/strided 경로
}
세 가지 경로로 분기한다: (1) scalar 연산 -- 한쪽이 단일 원소, (2) broadcast 연산 -- 한쪽이 broadcast 가능, (3) 기존 dense/strided 경로. 각각 전용 Metal 커널을 사용한다.
Metal 셰이더 커널
template <typename T, typename F, typename om_t = opmath_t<T>>
kernel void binary_dense_broadcast(
device result_of<F, T, T>* out [[buffer(0)]],
constant T* input [[buffer(1)]],
constant T* broadcast [[buffer(2)]],
constant long& broadcast_numel [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
F f;
using res_t = result_of<F, T, T>;
out[tid] = static_cast<res_t>(
f(om_t(input[tid]), om_t(broadcast[tid % broadcast_numel])));
}
tid % broadcast_numel로 broadcast 텐서를 인덱싱하는 것이 핵심이다. Strided iterator가 필요 없으므로 shape/stride 계산 오버헤드가 사라진다.
왜 이게 좋은가
MPS에서 binary 연산(mul, add 등)에 broadcast가 포함되면, 기존에는 strided 커널을 사용했다. Strided 커널은 매 원소마다 다차원 인덱스를 계산해야 하므로 느리다. 전용 broadcast 커널은 modulo 연산 하나로 인덱싱을 해결하여, 특히 tensor * scalar 같은 빈번한 패턴에서 큰 성능 개선을 가져온다.
정리
is_dense_broadcastable함수로 broadcast 가능 여부를 빠르게 판별한다- Scalar, broadcast, dense, strided 네 경로로 분기하여 최적 커널을 선택한다
binary_dense_broadcast,binary_dense_scalar등 12종의 새 Metal 커널 템플릿이 추가되었다- LHS/RHS 양쪽 모두에 대한 broadcast 및 cast 변형도 지원한다
참고 자료
- PyTorch MPS 백엔드 -- Apple Silicon GPU 가속
- Metal Shading Language -- Metal 셰이더 스펙
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Triton] AMD Gluon DSL에 TDM L2 Prefetch 노출 — 사용자 수준 프리페치 제어
- 현재글 : [PyTorch] MPS mul 성능 회귀 수정
- 다음글 [vllm] MORI KV Connector - ROCm 기반 Prefill-Decode Disaggregation
댓글