diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 11a2d07d2096d..1d4e6cbac164e 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -247,6 +247,11 @@
     index_add_,
     index_put,
     index_put_,
+    column_stack,
+    row_stack,
+    dstack,
+    hstack,
+    vstack,
     unflatten,
     as_strided,
     view,
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index ce4cfc8ee883b..737349f2681c2 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -160,6 +160,11 @@
 from .manipulation import index_add_  # noqa: F401
 from .manipulation import index_put  # noqa: F401
 from .manipulation import index_put_  # noqa: F401
+from .manipulation import column_stack  # noqa: F401
+from .manipulation import row_stack  # noqa: F401
+from .manipulation import dstack  # noqa: F401
+from .manipulation import hstack  # noqa: F401
+from .manipulation import vstack  # noqa: F401
 from .manipulation import unflatten  # noqa: F401
 from .manipulation import as_strided  # noqa: F401
 from .manipulation import view  # noqa: F401
@@ -669,6 +674,11 @@
     "index_add_",
     'index_put',
     'index_put_',
+    'column_stack',
+    'row_stack',
+    'dstack',
+    'hstack',
+    'vstack',
     'take',
     'bucketize',
     'sgn',
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index ae61880c997be..bead809772172 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -5124,6 +5124,302 @@ def index_put(x, indices, value, accumulate=False, name=None):
     return out
 
 
+def column_stack(x, name=None):
+    """
+    Stacks 1-D tensors as columns into a 2-D tensor.
+    1-D tensors are first turned into 2-D columns, and 2-D tensors are then stacked as-is, just like with ``hstack``.
+
+    Args:
+        x (tuple[Tensor] or list[Tensor]): A sequence of tensors to concatenate.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
+
+    Returns:
+        2-D Tensor, formed by stacking the given tensors.
+
+    Raises:
+        TypeError: If `x` is not a list or tuple.
+        TypeError: If an element of `x` is not a Tensor.
+        ValueError: If `x` is empty.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> a = paddle.to_tensor([1, 2, 3])
+            >>> b = paddle.to_tensor([4, 5, 6])
+            >>> c = paddle.column_stack((a, b))
+            >>> print(c)
+            Tensor(shape=[3, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
+                   [[1, 4],
+                    [2, 5],
+                    [3, 6]])
+            >>> a = paddle.arange(5)
+            >>> b = paddle.arange(10).reshape([5, 2])
+            >>> c = paddle.column_stack((a, b, b))
+            >>> print(c)
+            Tensor(shape=[5, 5], dtype=int64, place=Place(cpu), stop_gradient=True,
+                   [[0, 0, 1, 0, 1],
+                    [1, 2, 3, 2, 3],
+                    [2, 4, 5, 4, 5],
+                    [3, 6, 7, 6, 7],
+                    [4, 8, 9, 8, 9]])
+    """
+    check_type(x, 'x', (list, tuple), 'column_stack')
+    if not x:
+        msg = "For 'column_stack', inputs can not be empty"
+        raise ValueError(msg)
+    trans_x = ()
+
+    for tensor in x:
+        check_type(
+            tensor,
+            'tensor',
+            (Variable),
+            'column_stack',
+            f"For 'column_stack', each element of 'x' must be a Tensor, but got {type(tensor)}",
+        )
+        if tensor.ndim < 1:
+            tensor = paddle.unsqueeze(tensor, 0)
+        if tensor.ndim == 1:
+            tensor = paddle.unsqueeze(tensor, 1)
+        trans_x += (tensor,)
+    if not trans_x:
+        raise ValueError(
+            "For 'column_stack', the input must have at least 1 tensor, but got 0."
+        )
+
+    return paddle.concat(trans_x, 1)
+
+
+def dstack(x, name=None):
+    """
+    Stacks tensors in sequence depthwise (along the third axis).
+    1-D tensors of shape :math:`(N,)` are reshaped to :math:`(1, N, 1)` and 2-D tensors of shape :math:`(M, N)` are reshaped to :math:`(M, N, 1)` before concatenation along the third axis.
+
+    Args:
+        x (tuple[Tensor] or list[Tensor]): A sequence of tensors to concatenate.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
+
+    Returns:
+        Stacked Tensor, which is at least 3-D.
+
+    Raises:
+        TypeError: If `x` is not a tuple or list.
+        ValueError: If `x` is empty.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> a = paddle.to_tensor([1, 2, 3])
+            >>> b = paddle.to_tensor([4, 5, 6])
+            >>> c = paddle.dstack((a, b))
+            >>> print(c)
+            Tensor(shape=[1, 3, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
+                   [[[1, 4],
+                     [2, 5],
+                     [3, 6]]])
+            >>> a = paddle.to_tensor([[1], [2], [3]])
+            >>> b = paddle.to_tensor([[4], [5], [6]])
+            >>> c = paddle.dstack((a, b))
+            >>> print(c)
+            Tensor(shape=[3, 1, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
+                   [[[1, 4]],
+                    [[2, 5]],
+                    [[3, 6]]])
+    """
+    check_type(x, 'x', (list, tuple), 'dstack')
+    if not x:
+        msg = "For 'dstack', inputs can not be empty"
+        raise ValueError(msg)
+    rep = ()
+    for tensor in x:
+        check_type(
+            tensor,
+            'tensor',
+            (Variable),
+            'dstack',
+            f"For 'dstack', each element of 'x' must be a Tensor, but got {type(tensor)}",
+        )
+        if tensor.size == 0:
+            raise TypeError(
+                "For 'dstack', each element of 'x' can not be an empty Tensor."
+            )
+        ndim = tensor.ndim
+        # similar to the function of atleast_3d
+        if ndim == 0:
+            tensor = paddle.reshape(tensor, (1, 1, 1))
+        if ndim == 1:
+            size = tensor.shape[0]
+            tensor = paddle.reshape(tensor, (1, size, 1))
+        if ndim == 2:
+            tensor = paddle.unsqueeze(tensor, axis=-1)
+        rep += (tensor,)
+    if not rep:
+        raise ValueError(
+            "For 'dstack', at least one tensor is needed to concatenate."
+        )
+    return paddle.concat(rep, 2)
+
+
+def hstack(x, name=None):
+    """
+    Stacks tensors in sequence horizontally (column wise).
+    1-D tensors are concatenated along the first axis; all other tensors are concatenated along the second axis.
+
+    Args:
+        x (tuple[Tensor] or list[Tensor]): A sequence of tensors to concatenate.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
+
+    Returns:
+        Stacked Tensor, formed by stacking the given tensors.
+
+    Raises:
+        TypeError: If `x` is not a list or tuple.
+        TypeError: If an element of `x` is not a Tensor.
+        ValueError: If `x` is empty.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> a = paddle.to_tensor([1, 2, 3])
+            >>> b = paddle.to_tensor([4, 5, 6])
+            >>> c = paddle.hstack((a, b))
+            >>> print(c)
+            Tensor(shape=[6], dtype=int64, place=Place(cpu), stop_gradient=True,
+                   [1, 2, 3, 4, 5, 6])
+            >>> a = paddle.to_tensor([[1], [2], [3]])
+            >>> b = paddle.to_tensor([[4], [5], [6]])
+            >>> c = paddle.hstack((a, b))
+            >>> print(c)
+            Tensor(shape=[3, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
+                   [[1, 4],
+                    [2, 5],
+                    [3, 6]])
+    """
+    check_type(x, 'x', (list, tuple), 'hstack')
+    if not x:
+        msg = "For 'hstack', inputs can not be empty"
+        raise ValueError(msg)
+    rep = ()
+    for tensor in x:
+        check_type(
+            tensor,
+            'tensor',
+            (Variable),
+            'hstack',
+            f"For 'hstack', each element of 'x' must be a Tensor, but got {type(tensor)}",
+        )
+        # similar to the function of atleast_1d
+        if tensor.ndim == 0:
+            tensor = paddle.reshape(tensor, [1])
+        rep += (tensor,)
+    if not rep:
+        raise ValueError(
+            "For 'hstack', at least one tensor is needed to concatenate."
+        )
+
+    if rep[0].ndim == 1:
+        return paddle.concat(rep, 0)
+
+    return paddle.concat(rep, 1)
+
+
+def vstack(x, name=None):
+    """
+    Stacks tensors in sequence vertically (row wise).
+    1-D tensors of shape :math:`(N,)` are reshaped to :math:`(1, N)` before concatenation along the first axis.
+
+    Args:
+        x (tuple[Tensor] or list[Tensor]): A sequence of tensors to concatenate.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
+
+    Returns:
+        Stacked Tensor, formed by stacking the given tensors along the first axis.
+
+    Raises:
+        TypeError: If `x` is not a list or tuple.
+        ValueError: If `x` is empty.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> a = paddle.to_tensor([1, 2, 3])
+            >>> b = paddle.to_tensor([4, 5, 6])
+            >>> c = paddle.vstack((a, b))
+            >>> print(c)
+            Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
+                   [[1, 2, 3],
+                    [4, 5, 6]])
+            >>> a = paddle.to_tensor([[1], [2], [3]])
+            >>> b = paddle.to_tensor([[4], [5], [6]])
+            >>> c = paddle.vstack((a, b))
+            >>> print(c)
+            Tensor(shape=[6, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
+                   [[1],
+                    [2],
+                    [3],
+                    [4],
+                    [5],
+                    [6]])
+    """
+    check_type(x, 'x', (list, tuple), 'vstack')
+    if not x:
+        msg = "For 'vstack', inputs can not be empty"
+        raise ValueError(msg)
+
+    rep = ()
+    for tensor in x:
+        check_type(
+            tensor,
+            'tensor',
+            (Variable),
+            'vstack',
+            f"For 'vstack', each element of 'x' must be a Tensor, but got {type(tensor)}",
+        )
+        # similar to the function of atleast_2d
+        if tensor.ndim == 0:
+            tensor = paddle.reshape(tensor, [1, 1])
+        elif tensor.ndim == 1:
+            tensor = paddle.unsqueeze(tensor, 0)
+        rep += (tensor,)
+    if not rep:
+        raise ValueError(
+            "For 'vstack', at least one tensor is needed to concatenate."
+        )
+    return paddle.concat(rep, 0)
+
+
+def row_stack(x, name=None):
+    """
+    Alias for :func:`paddle.vstack`.
+
+    Args:
+        x (tuple[Tensor] or list[Tensor]): A sequence of tensors to concatenate.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
+
+    Returns:
+        Stacked Tensor, formed by stacking the given tensors along the first axis.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> a = paddle.to_tensor([1, 2, 3])
+            >>> b = paddle.to_tensor([4, 5, 6])
+            >>> c = paddle.row_stack((a, b))
+            >>> print(c)
+            Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
+                   [[1, 2, 3],
+                    [4, 5, 6]])
+            >>> a = paddle.to_tensor([[1], [2], [3]])
+            >>> b = paddle.to_tensor([[4], [5], [6]])
+            >>> c = paddle.row_stack((a, b))
+            >>> print(c)
+            Tensor(shape=[6, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
+                   [[1],
+                    [2],
+                    [3],
+                    [4],
+                    [5],
+                    [6]])
+    """
+    return paddle.vstack(x, name=name)
+
+
 def unflatten(x, axis, shape, name=None):
     """
     Expand a certain dimension of the input x Tensor into a desired shape.
diff --git a/test/legacy_test/test_vhs_row_column_stack.py b/test/legacy_test/test_vhs_row_column_stack.py
new file mode 100644
index 0000000000000..038cd8445c416
--- /dev/null
+++ b/test/legacy_test/test_vhs_row_column_stack.py
@@ -0,0 +1,296 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+
+
+class StackBase(unittest.TestCase):
+    def setUp(self):
+        self.x = np.array(1.0, dtype="float64")
+        self.x_shape = []
+        self.x_dtype = "float64"
+        self.y = np.array(1.0, dtype="float64")
+        self.y_shape = []
+        self.y_dtype = "float64"
+        self.place = (
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def static_api(self, func):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(
+                name='x', shape=self.x_shape, dtype=self.x_dtype
+            )
+            y = paddle.static.data(
+                name='y', shape=self.y_shape, dtype=self.y_dtype
+            )
+            f = getattr(paddle, func)
+            out = f((x, y))
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(
+                feed={
+                    'x': self.x,
+                    'y': self.y,
+                },
+                fetch_list=[out],
+            )
+        f = getattr(np, func)
+        expect_output = f((self.x, self.y))
+        np.testing.assert_allclose(expect_output, res[0], atol=1e-05)
+
+    def dygraph_api(self, func):
+        paddle.disable_static(self.place)
+        x = paddle.to_tensor(self.x)
+        y = paddle.to_tensor(self.y)
+        f = getattr(paddle, func)
+        res = f((x, y))
+        f = getattr(np, func)
+        expect_output = f((self.x, self.y))
+
+        np.testing.assert_allclose(expect_output, res.numpy(), atol=1e-05)
+        paddle.enable_static()
+
+
+class TestColumnStackTwo(StackBase):
+    def setUp(self):
+        # stack two 1-D tensors
+        self.x = np.array([1, 2, 3], dtype="int32")
+        self.x_shape = [3]
+        self.x_dtype = "int32"
+        self.y = np.array([4, 5, 6], dtype="int32")
+        self.y_shape = [3]
+        self.y_dtype = "int32"
+        self.place = (
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def test_static_api(self):
+        super().static_api('column_stack')
+
+    def test_dygraph_api(self):
+        super().dygraph_api('column_stack')
+
+
+class TestColumnStackThree(StackBase):
+    def setUp(self):
+        # stack a 1-D tensor with a 2-D tensor
+        self.x = np.array([0, 1, 2, 3, 4], dtype="int32")
+        self.x_shape = [5]
+        self.x_dtype = "int32"
+        self.y = np.array(
+            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], dtype="int32"
+        )
+        self.y_shape = [5, 2]
+        self.y_dtype = "int32"
+        self.place = (
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def test_static_api(self):
+        super().static_api('column_stack')
+
+    def test_dygraph_api(self):
+        super().dygraph_api('column_stack')
+
+
+class TestHStack1D2T(StackBase):
+    def setUp(self):
+        # stack two 1-D tensors
+        self.x = np.array([1, 2, 3], dtype="int32")
+        self.x_shape = [3]
+        self.x_dtype = "int32"
+        self.y = np.array([4, 5, 6], dtype="int32")
+        self.y_shape = [3]
+        self.y_dtype = "int32"
+        self.place = (
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def test_static_api(self):
+        super().static_api('hstack')
+
+    def test_dygraph_api(self):
+        super().dygraph_api('hstack')
+
+
+class TestHStack2D2T(StackBase):
+    def setUp(self):
+        # stack two 2-D tensors
+        self.x = np.array([[1], [2], [3]], dtype="int32")
+        self.x_shape = [3, 1]
+        self.x_dtype = "int32"
+        self.y = np.array([[4], [5], [6]], dtype="int32")
+        self.y_shape = [3, 1]
+        self.y_dtype = "int32"
+        self.place = (
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def test_static_api(self):
+        super().static_api('hstack')
+
+    def test_dygraph_api(self):
+        super().dygraph_api('hstack')
+
+
+class TestDStack1D2T(StackBase):
+    def setUp(self):
+        # stack two 1-D tensors
+        self.x = np.array([1, 2, 3], dtype="int32")
+        self.x_shape = [3]
+        self.x_dtype = "int32"
+        self.y = np.array([4, 5, 6], dtype="int32")
+        self.y_shape = [3]
+        self.y_dtype = "int32"
+        self.place = (
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def test_static_api(self):
+        super().static_api('dstack')
+
+    def test_dygraph_api(self):
+        super().dygraph_api('dstack')
+
+
+class TestDStack2D2T(StackBase):
+    def setUp(self):
+        # stack two 2-D tensors
+        self.x = np.array([[1], [2], [3]], dtype="int32")
+        self.x_shape = [3, 1]
+        self.x_dtype = "int32"
+        self.y = np.array([[4], [5], [6]], dtype="int32")
+        self.y_shape = [3, 1]
+        self.y_dtype = "int32"
+        self.place = (
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def test_static_api(self):
+        super().static_api('dstack')
+
+    def test_dygraph_api(self):
+        super().dygraph_api('dstack')
+
+
+class TestRowStack1D2T(StackBase):
+    def setUp(self):
+        # stack two 1-D tensors
+        self.x = np.array([1, 2, 3], dtype="int32")
+        self.x_shape = [3]
+        self.x_dtype = "int32"
+        self.y = np.array([4, 5, 6], dtype="int32")
+        self.y_shape = [3]
+        self.y_dtype = "int32"
+        self.place = (
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def test_static_api(self):
+        super().static_api('row_stack')
+
+    def test_dygraph_api(self):
+        super().dygraph_api('row_stack')
+
+
+class TestRowStack2D2T(StackBase):
+    def setUp(self):
+        # stack two 2-D tensors
+        self.x = np.array([[1], [2], [3]], dtype="int32")
+        self.x_shape = [3, 1]
+        self.x_dtype = "int32"
+        self.y = np.array([[4], [5], [6]], dtype="int32")
+        self.y_shape = [3, 1]
+        self.y_dtype = "int32"
+        self.place = (
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def test_static_api(self):
+        super().static_api('row_stack')
+
+    def test_dygraph_api(self):
+        super().dygraph_api('row_stack')
+
+
+class TestVStack1D2T(StackBase):
+    def setUp(self):
+        # stack two 1-D tensors
+        self.x = np.array([1, 2, 3], dtype="int32")
+        self.x_shape = [3]
+        self.x_dtype = "int32"
+        self.y = np.array([4, 5, 6], dtype="int32")
+        self.y_shape = [3]
+        self.y_dtype = "int32"
+        self.place = (
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def test_static_api(self):
+        super().static_api('vstack')
+
+    def test_dygraph_api(self):
+        super().dygraph_api('vstack')
+
+
+class TestVStack2D2T(StackBase):
+    def setUp(self):
+        # stack two 2-D tensors
+        self.x = np.array([[1], [2], [3]], dtype="int32")
+        self.x_shape = [3, 1]
+        self.x_dtype = "int32"
+        self.y = np.array([[4], [5], [6]], dtype="int32")
+        self.y_shape = [3, 1]
+        self.y_dtype = "int32"
+        self.place = (
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def test_static_api(self):
+        super().static_api('vstack')
+
+    def test_dygraph_api(self):
+        super().dygraph_api('vstack')
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    unittest.main()
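
A minimal usage sketch (not part of the patch above) of how the five new APIs are expected to mirror their NumPy counterparts in dygraph mode, in the same spirit as the unit tests; the helper name numpy_parity_check is illustrative only:

import numpy as np

import paddle


def numpy_parity_check():
    # Two 1-D inputs; each paddle.*stack result should match the np.*stack result exactly.
    a_np = np.array([1, 2, 3])
    b_np = np.array([4, 5, 6])
    a, b = paddle.to_tensor(a_np), paddle.to_tensor(b_np)
    for name in ("column_stack", "row_stack", "dstack", "hstack", "vstack"):
        out = getattr(paddle, name)((a, b)).numpy()
        ref = getattr(np, name)((a_np, b_np))
        np.testing.assert_array_equal(out, ref)


if __name__ == "__main__":
    numpy_parity_check()
    print("paddle stacking APIs match NumPy for the sample inputs")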