diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index d5748145ffe49..c05d51d64ba08 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -413,6 +413,17 @@
     func : conv3d_transpose_grad
     data_type : x
 
+- backward_op : copysign_grad
+  forward : copysign (Tensor x, Tensor y) -> Tensor(out)
+  args : (Tensor x, Tensor y, Tensor out_grad)
+  output : Tensor(x_grad), Tensor(y_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [x, y]
+  kernel :
+    func : copysign_grad
+  inplace : (out_grad -> x_grad)
+
 - backward_op : cos_double_grad
   forward : cos_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x)
   args : (Tensor x, Tensor grad_out, Tensor grad_x_grad)
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
old mode 100644
new mode 100755
index 65ca863db5d4b..43294b7a4d7c2
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -578,6 +578,16 @@
     data_type : x
     backward : conv3d_transpose_grad
 
+- op : copysign
+  args : (Tensor x, Tensor y)
+  output : Tensor(out)
+  infer_meta :
+    func : ElementwiseInferMeta
+  kernel :
+    func : copysign
+  inplace: (x -> out)
+  backward : copysign_grad
+
 - op : cos
   args : (Tensor x)
   output : Tensor(out)
diff --git a/paddle/phi/kernels/copysign_grad_kernel.h b/paddle/phi/kernels/copysign_grad_kernel.h
new file mode 100755
index 0000000000000..38c44dc4bd6a0
--- /dev/null
+++ b/paddle/phi/kernels/copysign_grad_kernel.h
@@ -0,0 +1,53 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/funcs/elementwise_functor.h"
+
+namespace phi {
+
+using float16 = phi::dtype::float16;
+using bfloat16 = phi::dtype::bfloat16;
+
+template <typename T>
+inline HOSTDEVICE auto copysign_func(const T& a, const T& b) {
+#ifdef WIN32
+  using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
+  return static_cast<T>(std::copysign(static_cast<U>(a), static_cast<U>(b)));
+#else
+  return static_cast<T>(std::copysign(a, b));
+#endif
+}
+
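+// std::copysign has no overload for float16/bfloat16, so the two overloads
+// below work on the raw uint16 representation instead: in IEEE half
+// precision and in bfloat16, bit 15 (mask 0x8000) holds the sign and bits
+// 0-14 (mask 0x7fff) hold exponent and mantissa, so
+// (a.x & 0x7fff) | (b.x & 0x8000) is |a| carrying b's sign.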
+inline HOSTDEVICE phi::dtype::float16 copysign_func(phi::dtype::float16 a,
+                                                    phi::dtype::float16 b) {
+  return phi::dtype::raw_uint16_to_float16((a.x & 0x7fff) | (b.x & 0x8000));
+}
+
+inline HOSTDEVICE phi::dtype::bfloat16 copysign_func(phi::dtype::bfloat16 a,
+                                                     phi::dtype::bfloat16 b) {
+  return phi::dtype::raw_uint16_to_bfloat16((a.x & 0x7fff) | (b.x & 0x8000));
+}
+
+template <typename T, typename Context>
+void CopySignGradKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& y,
+                        const DenseTensor& out_grad,
+                        DenseTensor* x_grad,
+                        DenseTensor* y_grad);
+}  // namespace phi
diff --git a/paddle/phi/kernels/copysign_kernel.h b/paddle/phi/kernels/copysign_kernel.h
new file mode 100644
index 0000000000000..68b0b5463924d
--- /dev/null
+++ b/paddle/phi/kernels/copysign_kernel.h
@@ -0,0 +1,63 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+namespace phi {
+
+using float16 = phi::dtype::float16;
+using bfloat16 = phi::dtype::bfloat16;
+
+template <typename T>
+inline HOSTDEVICE auto copysign_func(const T& a, const T& b) {
+#ifdef WIN32
+  using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
+  return static_cast<T>(std::copysign(static_cast<U>(a), static_cast<U>(b)));
+#else
+  return static_cast<T>(std::copysign(a, b));
+#endif
+}
+
+inline HOSTDEVICE phi::dtype::float16 copysign_func(phi::dtype::float16 a,
+                                                    phi::dtype::float16 b) {
+  return phi::dtype::raw_uint16_to_float16((a.x & 0x7fff) | (b.x & 0x8000));
+}
+
+inline HOSTDEVICE phi::dtype::bfloat16 copysign_func(phi::dtype::bfloat16 a,
+                                                     phi::dtype::bfloat16 b) {
+  return phi::dtype::raw_uint16_to_bfloat16((a.x & 0x7fff) | (b.x & 0x8000));
+}
+
+template <typename T>
+struct CopySignFunctor {
+  inline HOSTDEVICE T operator()(const T a, const T b) const {
+    return copysign_func(a, b);
+  }
+};
+template <typename T>
+struct InverseCopySignFunctor {
+  inline HOSTDEVICE T operator()(const T a, const T b) const {
+    return copysign_func(b, a);
+  }
+};
+
+template <typename T, typename Context>
+void CopySignKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& y,
+                    DenseTensor* out);
+}  // namespace phi
diff --git a/paddle/phi/kernels/cpu/copysign_grad_kernel.cc b/paddle/phi/kernels/cpu/copysign_grad_kernel.cc
new file mode 100644
index 0000000000000..5f803e9309c8b
--- /dev/null
+++ b/paddle/phi/kernels/cpu/copysign_grad_kernel.cc
@@ -0,0 +1,82 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/copysign_grad_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/cpu/elementwise_grad.h"
+
+namespace phi {
+
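+// out = copysign(x, y) = sign(y) * |x|, so away from x == 0 the derivative
+// w.r.t. x is sign(y) * sign(x) = copysign(x, y) / x, and the chain rule
+// gives dx = dout * copysign(x, y) / x. At x == 0 the function is not
+// differentiable and a zero subgradient is returned. out is piecewise
+// constant in y, so dy is identically zero.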
+ +#include "paddle/phi/kernels/copysign_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/elementwise_grad.h" + +namespace phi { + +template +HOSTDEVICE T compute_copysign_grad_dx(T x, T y, T out, T dout) { + if (x == static_cast(0)) + return x; + else + return static_cast(dout * (phi::copysign_func(x, y) / x)); +} + +template +struct CopySignGradDX { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return compute_copysign_grad_dx(x, y, out, dout); + } +}; + +template +struct CopySignGradDY { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return static_cast(0); + } +}; + +template +void CopySignGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + DenseTensor* x_grad, + DenseTensor* y_grad) { + funcs::ElementwiseGradPreProcess(out_grad, x_grad); + int axis = -1; + phi::funcs:: + ElemwiseGradCompute, CopySignGradDY>( + dev_ctx, + x, + y, + out_grad, + out_grad, + axis, + x_grad, + y_grad, + CopySignGradDX(), + CopySignGradDY()); +} +} // namespace phi + +PD_REGISTER_KERNEL(copysign_grad, + CPU, + ALL_LAYOUT, + phi::CopySignGradKernel, + bool, + uint8_t, + int8_t, + int16_t, + int, + int64_t, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/cpu/copysign_kernel.cc b/paddle/phi/kernels/cpu/copysign_kernel.cc new file mode 100755 index 0000000000000..6df11bedd3a91 --- /dev/null +++ b/paddle/phi/kernels/cpu/copysign_kernel.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/copysign_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +namespace phi { +template +void CopySignKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + dev_ctx.template Alloc(out); + auto x_dims = x.dims(); + auto y_dims = y.dims(); + if (x_dims.size() >= y_dims.size()) { + funcs::ElementwiseCompute, T>( + dev_ctx, x, y, phi::CopySignFunctor(), out); + } else { + funcs::ElementwiseCompute, T>( + dev_ctx, x, y, phi::InverseCopySignFunctor(), out); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(copysign, + CPU, + ALL_LAYOUT, + phi::CopySignKernel, + bool, + uint8_t, + int8_t, + int16_t, + int, + int64_t, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/copysign_grad_kernel.cu b/paddle/phi/kernels/gpu/copysign_grad_kernel.cu new file mode 100644 index 0000000000000..7a9accf2fce80 --- /dev/null +++ b/paddle/phi/kernels/gpu/copysign_grad_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+  if (x_dims.size() >= y_dims.size()) {
+    funcs::ElementwiseCompute<CopySignFunctor<T>, T>(
+        dev_ctx, x, y, phi::CopySignFunctor<T>(), out);
+  } else {
+    funcs::ElementwiseCompute<InverseCopySignFunctor<T>, T>(
+        dev_ctx, x, y, phi::InverseCopySignFunctor<T>(), out);
+  }
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(copysign,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::CopySignKernel,
+                   bool,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/copysign_grad_kernel.cu b/paddle/phi/kernels/gpu/copysign_grad_kernel.cu
new file mode 100644
index 0000000000000..7a9accf2fce80
--- /dev/null
+++ b/paddle/phi/kernels/gpu/copysign_grad_kernel.cu
@@ -0,0 +1,99 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/copysign_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/gpu/elementwise_grad.h"
+
+namespace phi {
+
+template <typename T>
+struct CopySignGradXFunctor {
+  inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
+    if (x == static_cast<T>(0)) return x;
+    return dout * (phi::copysign_func(x, y) / x);
+  }
+};
+
+template <typename T>
+struct CopySignGradYFunctor {
+  inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
+    return static_cast<T>(0);
+  }
+};
+
+template <typename InT, typename OutT>
+struct CopySignGradXYFunctor {
+  inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT x,
+                                                   const InT y,
+                                                   const InT dout) {
+    phi::Array<OutT, 2> outs;
+    // dx
+    if (x == static_cast<InT>(0))
+      outs[0] = static_cast<OutT>(0);
+    else
+      outs[0] = static_cast<OutT>(dout * (phi::copysign_func(x, y)) / x);
+    // dy = 0
+    outs[1] = static_cast<OutT>(0);
+    return outs;
+  }
+};
+
+template <typename T, typename Context>
+void CopySignGradKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& y,
+                        const DenseTensor& out_grad,
+                        DenseTensor* x_grad,
+                        DenseTensor* y_grad) {
+  const auto place = dev_ctx.GetPlace();
+  int axis = -1;
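+  // Dispatch on which gradients were requested: a fused functor computes dx
+  // and dy in a single pass when both are needed, otherwise a single-output
+  // functor computes just the one requested.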
+ +#include "paddle/phi/kernels/copysign_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +namespace phi { +template +void CopySignKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + std::vector inputs = {&x, &y}; + std::vector outputs = {out}; + dev_ctx.template Alloc(out); + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, phi::CopySignFunctor()); +} +} // namespace phi + +PD_REGISTER_KERNEL(copysign, + GPU, + ALL_LAYOUT, + phi::CopySignKernel, + bool, + uint8_t, + int8_t, + int16_t, + int, + int64_t, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 153fb194c1177..454a3291ec3d5 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -367,6 +367,8 @@ clip, combinations, conj, + copysign, + copysign_, cos, cos_, cosh, @@ -948,6 +950,8 @@ 'i1e', 'polygamma', 'polygamma_', + 'copysign', + 'copysign_', 'bitwise_left_shift', 'bitwise_left_shift_', 'bitwise_right_shift', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 330a5132e06f2..0ab10993b8aa7 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -246,6 +246,8 @@ clip_, combinations, conj, + copysign, + copysign_, cos, cos_, cosh, @@ -785,6 +787,8 @@ 'asinh_', 'diag', 'normal_', + 'copysign', + 'copysign_', 'normal_', 'bitwise_left_shift', 'bitwise_left_shift_', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 94d685f299edf..a050b6175780e 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -16,6 +16,7 @@ """ import math +import warnings import numpy as np @@ -7287,6 +7288,107 @@ def bitwise_right_shift_(x, y, is_arithmetic=True, out=None, name=None): return _C_ops.bitwise_right_shift_(x, y, is_arithmetic) +def copysign(x, y, name=None): + r""" + Create a new floating-point tensor with the magnitude of input ``x`` and the sign of ``y``, elementwise. + + Equation: + .. math:: + + copysign(x_{i},y_{i})=\left\{\begin{matrix} + & -|x_{i}| & if \space y_{i} <= -0.0\\ + & |x_{i}| & if \space y_{i} >= 0.0 + \end{matrix}\right. + + Args: + x (Tensor): The input Tensor, magnitudes, the data type is bool, uint8, int8, int16, int32, int64, bfloat16, float16, float32, float64. + y (Tensor, number): contains value(s) whose signbit(s) are applied to the magnitudes in input. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor), the output tensor. The data type is the same as the input tensor. + + Examples: + .. code-block:: python + :name: example1 + + >>> import paddle + >>> x = paddle.to_tensor([1, 2, 3], dtype='float64') + >>> y = paddle.to_tensor([-1, 1, -1], dtype='float64') + >>> out = paddle.copysign(x, y) + >>> print(out) + Tensor(shape=[3], dtype=float64, place=Place(gpu:0), stop_gradient=True, + [-1., 2., -3.]) + + .. code-block:: python + :name: example2 + + >>> x = paddle.to_tensor([1, 2, 3], dtype='float64') + >>> y = paddle.to_tensor([-2], dtype='float64') + >>> res = paddle.copysign(x, y) + >>> print(res) + Tensor(shape=[3], dtype=float64, place=Place(gpu:0), stop_gradient=True, + [-1., -2., -3.]) + + .. 
+@inplace_apis_in_dygraph_only
+def copysign_(x, y, name=None):
+    r"""
+    Inplace version of ``copysign`` API, the output Tensor will be inplaced with input ``x``.
+    Please refer to :ref:`api_paddle_copysign`.
+    """
+    if isinstance(y, (float, int)):
+        y = paddle.to_tensor(y, dtype=x.dtype)
+    out_shape = broadcast_shape(x.shape, y.shape)
+    if out_shape != x.shape:
+        raise ValueError(
+            "The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
+                out_shape, x.shape
+            )
+        )
+    return _C_ops.copysign_(x, y)
+
+
 def hypot(x, y, name=None):
     """
     Calculate the length of the hypotenuse of a right-angle triangle. The equation is:
diff --git a/test/legacy_test/test_copysign_op.py b/test/legacy_test/test_copysign_op.py
new file mode 100755
index 0000000000000..acfc4c0222386
--- /dev/null
+++ b/test/legacy_test/test_copysign_op.py
@@ -0,0 +1,303 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from op_test import OpTest, convert_float_to_uint16
+
+import paddle
+from paddle.base import core
+from paddle.pir_utils import test_with_pir_api
+
+np.random.seed(100)
+paddle.seed(100)
+
+
+def ref_copysign(x, y):
+    out_dtype = x.dtype
+    return np.copysign(x, y).astype(out_dtype)
+
+
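+# Analytic gradient, kept as a reference: for x != 0,
+# d copysign(x, y) / dx = copysign(x, y) / x = sign(x) * sign(y), so
+# dx = dout * out / x; the gradient w.r.t. y is identically zero.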
+def ref_grad_copysign(x, y, dout):
+    out = np.copysign(x, y)
+    return dout * out / x
+
+
+class TestCopySignOp(OpTest):
+    def setUp(self):
+        self.op_type = "copysign"
+        self.python_api = paddle.copysign
+        self.init_config()
+        self.inputs = {'x': self.x, 'y': self.y}
+        self.target = ref_copysign(self.inputs['x'], self.inputs['y'])
+        self.outputs = {'out': self.target}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['x', 'y'], ['out'])
+
+    def test_check_grad_ignore_x(self):
+        self.check_grad(['y'], ['out'], no_grad_set=set('x'))
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad(['x'], ['out'], no_grad_set=set('y'))
+
+    def init_config(self):
+        self.x = np.random.randn(20, 6).astype('float64')
+        self.y = np.random.randn(20, 6).astype('float64')
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support the bfloat16",
+)
+class TestCopySignBF16(OpTest):
+    def setUp(self):
+        self.op_type = "copysign"
+        self.python_api = paddle.copysign
+        self.init_dtype()
+        np.random.seed(1024)
+        x = np.random.randn(20, 6).astype(np.float32)
+        y = np.random.randn(20, 6).astype(np.float32)
+        out = ref_copysign(x, y)
+        self.inputs = {
+            'x': convert_float_to_uint16(x),
+            'y': convert_float_to_uint16(y),
+        }
+        self.outputs = {'out': convert_float_to_uint16(out)}
+        self.place = core.CUDAPlace(0)
+
+    def init_dtype(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['x', 'y'], ['out'])
+
+    def test_check_grad_ignore_x(self):
+        self.check_grad_with_place(
+            self.place, ['y'], ['out'], no_grad_set=set('x')
+        )
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad_with_place(
+            self.place, ['x'], ['out'], no_grad_set=set('y')
+        )
+
+
+class TestCopySignAPI(unittest.TestCase):
+    def setUp(self):
+        self.input_init()
+        self.place_init()
+
+    def input_init(self):
+        self.x = np.random.randn(20, 6).astype('float64')
+        self.y = np.random.randn(20, 6).astype('float64')
+
+    def place_init(self):
+        self.place = (
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    @test_with_pir_api
+    def test_static_api(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(
+                name='x', shape=self.x.shape, dtype=self.x.dtype
+            )
+            if isinstance(self.y, (float, int)):
+                y = self.y
+            else:
+                y = paddle.static.data(
+                    name='y', shape=self.y.shape, dtype=self.x.dtype
+                )
+            out = paddle.copysign(x, y)
+            exe = paddle.static.Executor(self.place)
+            if isinstance(self.y, (float, int)):
+                res = exe.run(
+                    paddle.static.default_main_program(),
+                    feed={"x": self.x},
+                    fetch_list=[out],
+                )
+            else:
+                res = exe.run(
+                    paddle.static.default_main_program(),
+                    feed={"x": self.x, "y": self.y},
+                    fetch_list=[out],
+                )
+
+            out_ref = ref_copysign(self.x, self.y)
+            np.testing.assert_allclose(out_ref, res[0])
+            out_ref_dtype = out_ref.dtype
+            np.testing.assert_equal((out_ref_dtype == res[0].dtype), True)
+        paddle.disable_static()
+
+    def test_dygraph_api(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(self.x)
+        y = paddle.to_tensor(self.y)
+        out = paddle.copysign(x, y)
+        out_ref = ref_copysign(self.x, self.y)
+        np.testing.assert_allclose(out_ref, out.numpy())
+        out_ref_dtype = out_ref.dtype
+        np.testing.assert_equal((out_ref_dtype == out.numpy().dtype), True)
+        paddle.enable_static()
+
+
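+# The subclasses below re-run the static- and dynamic-graph checks of
+# TestCopySignAPI for every registered dtype, as well as for scalar y,
+# zero-valued, zero-size, signed-zero and broadcasting inputs.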
+class TestCopySignBool(TestCopySignAPI):
+    def input_init(self):
+        dtype = np.bool_
+        self.x = (np.random.randn(10, 20) * 10).astype(dtype)
+        self.y = (np.random.randn(10, 20) * 10).astype(dtype)
+
+
+class TestCopySignUint8(TestCopySignAPI):
+    def input_init(self):
+        dtype = np.uint8
+        self.x = (np.random.randn(10, 20) * 10).astype(dtype)
+        self.y = (np.random.randn(10, 20) * 10).astype(dtype)
+
+
+class TestCopySignInt8(TestCopySignAPI):
+    def input_init(self):
+        dtype = np.int8
+        self.x = (np.random.randn(10, 20) * 10).astype(dtype)
+        self.y = (np.random.randn(10, 20) * 10).astype(dtype)
+
+
+class TestCopySignInt16(TestCopySignAPI):
+    def input_init(self):
+        dtype = np.int16
+        self.x = (np.random.randn(10, 20) * 10).astype(dtype)
+        self.y = (np.random.randn(10, 20) * 10).astype(dtype)
+
+
+class TestCopySignInt32(TestCopySignAPI):
+    def input_init(self):
+        dtype = np.int32
+        self.x = (np.random.randn(10, 20) * 10).astype(dtype)
+        self.y = (np.random.randn(10, 20) * 10).astype(dtype)
+
+
+class TestCopySignInt64(TestCopySignAPI):
+    def input_init(self):
+        dtype = np.int64
+        self.x = (np.random.randn(10, 20) * 10).astype(dtype)
+        self.y = (np.random.randn(10, 20) * 10).astype(dtype)
+
+
+class TestCopySignFloat16(TestCopySignAPI):
+    def input_init(self):
+        dtype = np.float16
+        self.x = (np.random.randn(10, 20) * 10).astype(dtype)
+        self.y = (np.random.randn(10, 20) * 10).astype(dtype)
+
+
+class TestCopySignFloat32(TestCopySignAPI):
+    def input_init(self):
+        dtype = np.float32
+        self.x = (np.random.randn(10, 20) * 10).astype(dtype)
+        self.y = (np.random.randn(10, 20) * 10).astype(dtype)
+
+
+class TestCopySignFloat64(TestCopySignAPI):
+    def input_init(self):
+        dtype = np.float64
+        self.x = (np.random.randn(10, 20) * 10).astype(dtype)
+        self.y = (np.random.randn(10, 20) * 10).astype(dtype)
+
+
+class TestCopySignNumberY(TestCopySignAPI):
+    def input_init(self):
+        dtype = np.float32
+        self.x = (np.random.randn(10, 20) * 10).astype(dtype)
+        self.y = -2.0
+
+
+class TestCopySignZeroCase1(TestCopySignAPI):
+    def input_init(self):
+        self.x = np.zeros(shape=(10, 20))
+        self.y = np.zeros(shape=(10, 20))
+
+
+class TestCopySignZeroCase2(TestCopySignAPI):
+    def input_init(self):
+        self.x = np.zeros(shape=(10, 20))
+        self.y = np.random.randn(10, 20)
+
+
+class TestCopySignZeroCase3(TestCopySignAPI):
+    def input_init(self):
+        self.x = np.random.randn(10, 20)
+        self.y = np.zeros(shape=(10, 20))
+
+
+class TestCopySignZeroDimCase1(TestCopySignAPI):
+    def input_init(self):
+        self.x = np.random.randn(0, 0)
+        self.y = np.random.randn(0, 0)
+
+
+class TestCopySignZeroDimCase2(TestCopySignAPI):
+    def input_init(self):
+        self.x = np.random.randn(0, 5, 10)
+        self.y = np.random.randn(0, 5, 10)
+
+
+class TestCopySignSpecialZeroCase1(TestCopySignAPI):
+    def input_init(self):
+        self.x = np.array([1, 2, 3]).astype(np.float32)
+        self.y = np.array([0, +0, -0]).astype(np.float32)
+
+
+class TestCopySignSpecialZeroCase2(TestCopySignAPI):
+    def input_init(self):
+        self.x = np.array([0, +0, -0]).astype(np.float32)
+        self.y = np.array([1, 2, 3]).astype(np.float32)
+
+
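+# Broadcasting is exercised in both directions; in case 3 y has the higher
+# rank, which on the CPU path is the branch handled by InverseCopySignFunctor.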
+class TestCopySignBroadcastCase1(TestCopySignAPI):
+    def input_init(self):
+        dtype = np.float16
+        self.x = (np.random.randn(3, 4, 5) * 10).astype(dtype)
+        self.y = (np.random.randn(5) * 10).astype(dtype)
+
+
+class TestCopySignBroadcastCase2(TestCopySignAPI):
+    def input_init(self):
+        dtype = np.float16
+        self.x = (np.random.randn(3, 4, 5) * 10).astype(dtype)
+        self.y = (np.random.randn(4, 5) * 10).astype(dtype)
+
+
+class TestCopySignBroadcastCase3(TestCopySignAPI):
+    def input_init(self):
+        dtype = np.float16
+        self.x = (np.random.randn(4, 5) * 10).astype(dtype)
+        self.y = (np.random.randn(3, 4, 5) * 10).astype(dtype)
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py
old mode 100644
new mode 100755
index ce2526638e450..5f9fcb7be1e64
--- a/test/legacy_test/test_inplace.py
+++ b/test/legacy_test/test_inplace.py
@@ -380,6 +380,30 @@ def test_continuous_inplace_backward(self):
         self.assertTrue(self.np_compare(grad_var_a_inplace, grad_var_a))
 
 
+class TestDygraphInplaceCopysign(TestDygraphInplace):
+    def init_data(self):
+        self.input_var_numpy = np.random.randn(10, 20)
+        self.dtype = "float32"
+        self.y = -3.0
+
+    def inplace_api_processing(self, var):
+        return paddle.copysign_(var, self.y)
+
+    def non_inplace_api_processing(self, var):
+        return paddle.copysign(var, self.y)
+
+    def test_leaf_inplace_var_error(self):
+        with paddle.base.dygraph.guard():
+            var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
+            var.stop_gradient = False
+            self.y = paddle.rand([2, 10, 20])
+
+            def leaf_inplace_error():
+                self.inplace_api_processing(var)
+
+            self.assertRaises(ValueError, leaf_inplace_error)
+
+
 class TestDygraphInplaceUnsqueeze(TestDygraphInplace):
     def non_inplace_api_processing(self, var):
         return paddle.unsqueeze(var, -1)