diff --git a/aten/src/ATen/native/mps/operations/UpSample.mm b/aten/src/ATen/native/mps/operations/UpSample.mm index ee4e30d46cfaeb..f4973f60001561 100644 --- a/aten/src/ATen/native/mps/operations/UpSample.mm +++ b/aten/src/ATen/native/mps/operations/UpSample.mm @@ -20,6 +20,10 @@ #include #include #include +#include +#include +#include +#include #include #include #include @@ -36,9 +40,9 @@ // supported resize_mode: 'nearest' | 'bilinear' | 'nearest-exact' static void upsample_out_template(const Tensor& input, IntArrayRef output_size, - c10::optional input_size_opt, // only used for backward pass - c10::optional scale_h_opt, - c10::optional scale_w_opt, + std::optional input_size_opt, // only used for backward pass + std::optional scale_h_opt, + std::optional scale_w_opt, const Tensor& output, bool align_corners, const c10::string_view resize_mode_str) { @@ -235,7 +239,7 @@ static void upsample_out_template(const Tensor& input, } // namespace mps -static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10::optional scale) { +static bool check_mps_compatibility(const c10::string_view resize_mode_str, std::optional scale) { static const bool is_macOS_13_0_or_newer = is_macos_13_or_newer(); if (!is_macOS_13_0_or_newer) { // passing scale factors to MPS's resize APIs is not supported on macOS < 13 @@ -258,7 +262,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10: } TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) -(const Tensor& input, IntArrayRef output_size, c10::optional scale, const Tensor& output) { +(const Tensor& input, IntArrayRef output_size, std::optional scale, const Tensor& output) { if (check_mps_compatibility("nearest", scale)) { mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest"); } else { @@ -270,7 +274,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10: (const Tensor& grad_output, IntArrayRef output_size, IntArrayRef input_size, - c10::optional scale, + std::optional scale, const Tensor& grad_input) { if (check_mps_compatibility("nearest", scale)) { mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest"); @@ -280,7 +284,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10: } TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps) -(const Tensor& input, IntArrayRef output_size, c10::optional scale, const Tensor& output) { +(const Tensor& input, IntArrayRef output_size, std::optional scale, const Tensor& output) { if (check_mps_compatibility("nearest-exact", scale)) { mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest-exact"); } else { @@ -292,7 +296,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10: (const Tensor& grad_output, IntArrayRef output_size, IntArrayRef input_size, - c10::optional scale, + std::optional scale, const Tensor& grad_input) { if (check_mps_compatibility("nearest-exact", scale)) { mps::upsample_out_template( @@ -305,8 +309,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10: TORCH_IMPL_FUNC(upsample_nearest2d_out_mps) (const Tensor& input, IntArrayRef output_size, - c10::optional scales_h, - c10::optional scales_w, + std::optional scales_h, + std::optional scales_w, const Tensor& output) { if (check_mps_compatibility("nearest", scales_w)) { mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest"); @@ -319,8 +323,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10: (const Tensor& grad_output, IntArrayRef output_size, IntArrayRef input_size, - c10::optional scales_h, - c10::optional scales_w, + std::optional scales_h, + std::optional scales_w, const Tensor& grad_input) { if (check_mps_compatibility("nearest", scales_w)) { mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest"); @@ -333,8 +337,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10: TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps) (const Tensor& input, IntArrayRef output_size, - c10::optional scales_h, - c10::optional scales_w, + std::optional scales_h, + std::optional scales_w, const Tensor& output) { if (check_mps_compatibility("nearest-exact", scales_w)) { mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest-exact"); @@ -347,8 +351,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10: (const Tensor& grad_output, IntArrayRef output_size, IntArrayRef input_size, - c10::optional scales_h, - c10::optional scales_w, + std::optional scales_h, + std::optional scales_w, const Tensor& grad_input) { if (check_mps_compatibility("nearest-exact", scales_w)) { mps::upsample_out_template( @@ -359,12 +363,38 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10: } } +TORCH_IMPL_FUNC(upsample_linear1d_out_mps) +(const Tensor& input, IntArrayRef output_size, bool align_corners, std::optional scale, const Tensor& output) { + if (check_mps_compatibility("bilinear", scale)) { + mps::upsample_out_template( + input, output_size, c10::nullopt, c10::nullopt, scale, output, align_corners, "bilinear"); + } else { + output.copy_(at::upsample_linear1d(input.to("cpu"), output_size, align_corners, scale)); + } +} + +TORCH_IMPL_FUNC(upsample_linear1d_backward_out_mps) +(const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + std::optional scale, + const Tensor& grad_input) { + if (check_mps_compatibility("bilinear", scale)) { + mps::upsample_out_template( + grad_output, output_size, input_size, c10::nullopt, scale, grad_input, align_corners, "bilinear"); + } else { + grad_input.copy_( + at::upsample_linear1d_backward(grad_output.to("cpu"), output_size, input_size, align_corners, scale)); + } +} + TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps) (const Tensor& input, IntArrayRef output_size, bool align_corners, - c10::optional scales_h, - c10::optional scales_w, + std::optional scales_h, + std::optional scales_w, const Tensor& output) { if (check_mps_compatibility("bilinear", scales_w)) { mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, align_corners, "bilinear"); @@ -378,8 +408,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10: IntArrayRef output_size, IntArrayRef input_size, bool align_corners, - c10::optional scales_h, - c10::optional scales_w, + std::optional scales_h, + std::optional scales_w, const Tensor& grad_input) { if (check_mps_compatibility("bilinear", scales_w)) { mps::upsample_out_template( diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8a58e5b9928fd7..4518f71349e828 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -12401,6 +12401,7 @@ dispatch: CPU: upsample_linear1d_out_cpu CUDA: upsample_linear1d_out_cuda + MPS: upsample_linear1d_out_mps - func: upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor python_module: nn @@ -12412,6 +12413,7 @@ dispatch: CPU: upsample_linear1d_backward_out_cpu CUDA: upsample_linear1d_backward_out_cuda + MPS: upsample_linear1d_backward_out_mps - func: upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor python_module: nn diff --git a/test/test_mps.py b/test/test_mps.py index 19b1eac86a39c0..b5f54a62ef0738 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -723,7 +723,6 @@ def mps_ops_modifier(ops): 'nn.functional.adaptive_max_pool3d': None, 'nn.functional.interpolatearea': None, 'nn.functional.interpolatebicubic': None, - 'nn.functional.interpolatelinear': None, 'nn.functional.interpolatetrilinear': None, # TODO: max_pool2d for integral types fails the numerical test 'nn.functional.max_pool2d': (integral_types() if product_version < 14.0 else