Skip to content

Commit

Permalink
Implement aten::upsample_linear1d on mps (#115031)
Browse files Browse the repository at this point in the history
Related to #77764

Co-authored-by: Nikita Shulga <[email protected]>
Pull Request resolved: #115031
Approved by: https://github.com/malfet
  • Loading branch information
kaieberl authored and pytorchmergebot committed Feb 26, 2024
1 parent 30625ae commit c59b141
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 21 deletions.
70 changes: 50 additions & 20 deletions aten/src/ATen/native/mps/operations/UpSample.mm
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
#include <ATen/ops/upsample_bilinear2d_backward.h>
#include <ATen/ops/upsample_bilinear2d_backward_native.h>
#include <ATen/ops/upsample_bilinear2d_native.h>
#include <ATen/ops/upsample_linear1d.h>
#include <ATen/ops/upsample_linear1d_backward.h>
#include <ATen/ops/upsample_linear1d_backward_native.h>
#include <ATen/ops/upsample_linear1d_native.h>
#include <ATen/ops/upsample_nearest1d.h>
#include <ATen/ops/upsample_nearest1d_backward.h>
#include <ATen/ops/upsample_nearest1d_backward_native.h>
Expand All @@ -36,9 +40,9 @@
// supported resize_mode: 'nearest' | 'bilinear' | 'nearest-exact'
static void upsample_out_template(const Tensor& input,
IntArrayRef output_size,
c10::optional<IntArrayRef> input_size_opt, // only used for backward pass
c10::optional<double> scale_h_opt,
c10::optional<double> scale_w_opt,
std::optional<IntArrayRef> input_size_opt, // only used for backward pass
std::optional<double> scale_h_opt,
std::optional<double> scale_w_opt,
const Tensor& output,
bool align_corners,
const c10::string_view resize_mode_str) {
Expand Down Expand Up @@ -235,7 +239,7 @@ static void upsample_out_template(const Tensor& input,

} // namespace mps

static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10::optional<double> scale) {
static bool check_mps_compatibility(const c10::string_view resize_mode_str, std::optional<double> scale) {
static const bool is_macOS_13_0_or_newer = is_macos_13_or_newer();
if (!is_macOS_13_0_or_newer) {
// passing scale factors to MPS's resize APIs is not supported on macOS < 13
Expand All @@ -258,7 +262,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
}

TORCH_IMPL_FUNC(upsample_nearest1d_out_mps)
(const Tensor& input, IntArrayRef output_size, c10::optional<double> scale, const Tensor& output) {
(const Tensor& input, IntArrayRef output_size, std::optional<double> scale, const Tensor& output) {
if (check_mps_compatibility("nearest", scale)) {
mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest");
} else {
Expand All @@ -270,7 +274,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
(const Tensor& grad_output,
IntArrayRef output_size,
IntArrayRef input_size,
c10::optional<double> scale,
std::optional<double> scale,
const Tensor& grad_input) {
if (check_mps_compatibility("nearest", scale)) {
mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest");
Expand All @@ -280,7 +284,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
}

TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps)
(const Tensor& input, IntArrayRef output_size, c10::optional<double> scale, const Tensor& output) {
(const Tensor& input, IntArrayRef output_size, std::optional<double> scale, const Tensor& output) {
if (check_mps_compatibility("nearest-exact", scale)) {
mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest-exact");
} else {
Expand All @@ -292,7 +296,7 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
(const Tensor& grad_output,
IntArrayRef output_size,
IntArrayRef input_size,
c10::optional<double> scale,
std::optional<double> scale,
const Tensor& grad_input) {
if (check_mps_compatibility("nearest-exact", scale)) {
mps::upsample_out_template(
Expand All @@ -305,8 +309,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
TORCH_IMPL_FUNC(upsample_nearest2d_out_mps)
(const Tensor& input,
IntArrayRef output_size,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
std::optional<double> scales_h,
std::optional<double> scales_w,
const Tensor& output) {
if (check_mps_compatibility("nearest", scales_w)) {
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest");
Expand All @@ -319,8 +323,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
(const Tensor& grad_output,
IntArrayRef output_size,
IntArrayRef input_size,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
std::optional<double> scales_h,
std::optional<double> scales_w,
const Tensor& grad_input) {
if (check_mps_compatibility("nearest", scales_w)) {
mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest");
Expand All @@ -333,8 +337,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps)
(const Tensor& input,
IntArrayRef output_size,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
std::optional<double> scales_h,
std::optional<double> scales_w,
const Tensor& output) {
if (check_mps_compatibility("nearest-exact", scales_w)) {
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest-exact");
Expand All @@ -347,8 +351,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
(const Tensor& grad_output,
IntArrayRef output_size,
IntArrayRef input_size,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
std::optional<double> scales_h,
std::optional<double> scales_w,
const Tensor& grad_input) {
if (check_mps_compatibility("nearest-exact", scales_w)) {
mps::upsample_out_template(
Expand All @@ -359,12 +363,38 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
}
}

TORCH_IMPL_FUNC(upsample_linear1d_out_mps)
(const Tensor& input, IntArrayRef output_size, bool align_corners, std::optional<double> scale, const Tensor& output) {
if (check_mps_compatibility("bilinear", scale)) {
mps::upsample_out_template(
input, output_size, c10::nullopt, c10::nullopt, scale, output, align_corners, "bilinear");
} else {
output.copy_(at::upsample_linear1d(input.to("cpu"), output_size, align_corners, scale));
}
}

TORCH_IMPL_FUNC(upsample_linear1d_backward_out_mps)
(const Tensor& grad_output,
IntArrayRef output_size,
IntArrayRef input_size,
bool align_corners,
std::optional<double> scale,
const Tensor& grad_input) {
if (check_mps_compatibility("bilinear", scale)) {
mps::upsample_out_template(
grad_output, output_size, input_size, c10::nullopt, scale, grad_input, align_corners, "bilinear");
} else {
grad_input.copy_(
at::upsample_linear1d_backward(grad_output.to("cpu"), output_size, input_size, align_corners, scale));
}
}

TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps)
(const Tensor& input,
IntArrayRef output_size,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
std::optional<double> scales_h,
std::optional<double> scales_w,
const Tensor& output) {
if (check_mps_compatibility("bilinear", scales_w)) {
mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, align_corners, "bilinear");
Expand All @@ -378,8 +408,8 @@ static bool check_mps_compatibility(const c10::string_view resize_mode_str, c10:
IntArrayRef output_size,
IntArrayRef input_size,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
std::optional<double> scales_h,
std::optional<double> scales_w,
const Tensor& grad_input) {
if (check_mps_compatibility("bilinear", scales_w)) {
mps::upsample_out_template(
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12401,6 +12401,7 @@
dispatch:
CPU: upsample_linear1d_out_cpu
CUDA: upsample_linear1d_out_cuda
MPS: upsample_linear1d_out_mps

- func: upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor
python_module: nn
Expand All @@ -12412,6 +12413,7 @@
dispatch:
CPU: upsample_linear1d_backward_out_cpu
CUDA: upsample_linear1d_backward_out_cuda
MPS: upsample_linear1d_backward_out_mps

- func: upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor
python_module: nn
Expand Down
1 change: 0 additions & 1 deletion test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,6 @@ def mps_ops_modifier(ops):
'nn.functional.adaptive_max_pool3d': None,
'nn.functional.interpolatearea': None,
'nn.functional.interpolatebicubic': None,
'nn.functional.interpolatelinear': None,
'nn.functional.interpolatetrilinear': None,
# TODO: max_pool2d for integral types fails the numerical test
'nn.functional.max_pool2d': (integral_types() if product_version < 14.0 else
Expand Down

0 comments on commit c59b141

Please sign in to comment.