Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
【Hackathon 5th No.33】为 Paddle 新增 atleast_1d / atleast_2d / atleast_3d…
Browse files Browse the repository at this point in the history
… API -part (PaddlePaddle#58323)

* [Init] add atleast api

* [Add] add atleast test

* [Fix] import atleast

* [Change] test_atleast.py to test_atleast_nd.py and add bool data type test

* [Update] update dtype supports and unittest

* [Fix] dtype error unittest

* [Change] static test with test_with_pir_api

* [Add] atleast_Nd as tensor method
megemini authored and SecretXV committed Nov 28, 2023
1 parent c765bd2 commit b6dbb2b
Showing 4 changed files with 648 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
@@ -199,6 +199,9 @@


from .tensor.manipulation import ( # noqa: F401
atleast_1d,
atleast_2d,
atleast_3d,
cast,
cast_,
concat,
@@ -833,6 +836,9 @@
'logspace',
'reshape',
'reshape_',
'atleast_1d',
'atleast_2d',
'atleast_3d',
'reverse',
'nonzero',
'CUDAPinnedPlace',
6 changes: 6 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -112,6 +112,9 @@
from .logic import isclose # noqa: F401
from .logic import equal_all # noqa: F401
from .logic import is_tensor # noqa: F401
from .manipulation import atleast_1d # noqa: F401
from .manipulation import atleast_2d # noqa: F401
from .manipulation import atleast_3d # noqa: F401
from .manipulation import cast # noqa: F401
from .manipulation import cast_ # noqa: F401
from .manipulation import concat # noqa: F401
@@ -731,6 +734,9 @@
'normal_',
'index_fill',
'index_fill_',
'atleast_1d',
'atleast_2d',
'atleast_3d',
]

# this list used in math_op_patch.py for magic_method bind
178 changes: 178 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
@@ -3955,6 +3955,184 @@ def reshape_(x, shape, name=None):
return out


def atleast_1d(*inputs, name=None):
"""
Convert inputs to tensors and return the view with at least 1-dimension. Scalar inputs are converted,
one or high-dimensional inputs are preserved.
Args:
inputs (Tensor|list(Tensor)): One or more tensors. The data type is ``float16``, ``float32``, ``float64``, ``int16``, ``int32``, ``int64``, ``int8``, ``uint8``, ``complex64``, ``complex128``, ``bfloat16`` or ``bool``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
One Tensor, if there is only one input.
List of Tensors, if there are more than one inputs.
Examples:
.. code-block:: python
>>> import paddle
>>> # one input
>>> x = paddle.to_tensor(123, dtype='int32')
>>> out = paddle.atleast_1d(x)
>>> print(out)
Tensor(shape=[1], dtype=int32, place=Place(cpu), stop_gradient=True,
[123])
>>> # more than one inputs
>>> x = paddle.to_tensor(123, dtype='int32')
>>> y = paddle.to_tensor([1.23], dtype='float32')
>>> out = paddle.atleast_1d(x, y)
>>> print(out)
[Tensor(shape=[1], dtype=int32, place=Place(cpu), stop_gradient=True,
[123]), Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
[1.23000002])]
>>> # more than 1-D input
>>> x = paddle.to_tensor(123, dtype='int32')
>>> y = paddle.to_tensor([[1.23]], dtype='float32')
>>> out = paddle.atleast_1d(x, y)
>>> print(out)
[Tensor(shape=[1], dtype=int32, place=Place(cpu), stop_gradient=True,
[123]), Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
[[1.23000002]])]
"""
out = []
for tensor in inputs:
tensor = paddle.to_tensor(tensor)
if tensor.dim() == 0:
result = tensor.reshape((1,))
else:
result = tensor
out.append(result)

if len(out) == 1:
return out[0]
else:
return out


def atleast_2d(*inputs, name=None):
"""
Convert inputs to tensors and return the view with at least 2-dimension. Two or high-dimensional inputs are preserved.
Args:
inputs (Tensor|list(Tensor)): One or more tensors. The data type is ``float16``, ``float32``, ``float64``, ``int16``, ``int32``, ``int64``, ``int8``, ``uint8``, ``complex64``, ``complex128``, ``bfloat16`` or ``bool``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
One Tensor, if there is only one input.
List of Tensors, if there are more than one inputs.
Examples:
.. code-block:: python
>>> import paddle
>>> # one input
>>> x = paddle.to_tensor(123, dtype='int32')
>>> out = paddle.atleast_2d(x)
>>> print(out)
Tensor(shape=[1, 1], dtype=int32, place=Place(cpu), stop_gradient=True,
[[123]])
>>> # more than one inputs
>>> x = paddle.to_tensor(123, dtype='int32')
>>> y = paddle.to_tensor([1.23], dtype='float32')
>>> out = paddle.atleast_2d(x, y)
>>> print(out)
[Tensor(shape=[1, 1], dtype=int32, place=Place(cpu), stop_gradient=True,
[[123]]), Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
[[1.23000002]])]
>>> # more than 2-D input
>>> x = paddle.to_tensor(123, dtype='int32')
>>> y = paddle.to_tensor([[[1.23]]], dtype='float32')
>>> out = paddle.atleast_2d(x, y)
>>> print(out)
[Tensor(shape=[1, 1], dtype=int32, place=Place(cpu), stop_gradient=True,
[[123]]), Tensor(shape=[1, 1, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[1.23000002]]])]
"""
out = []
for tensor in inputs:
tensor = paddle.to_tensor(tensor)
if tensor.dim() == 0:
result = tensor.reshape((1, 1))
elif tensor.dim() == 1:
result = paddle.unsqueeze(tensor, axis=0)
else:
result = tensor
out.append(result)

if len(out) == 1:
return out[0]
else:
return out


def atleast_3d(*inputs, name=None):
"""
Convert inputs to tensors and return the view with at least 3-dimension. Three or high-dimensional inputs are preserved.
Args:
inputs (Tensor|list(Tensor)): One or more tensors. The data type is ``float16``, ``float32``, ``float64``, ``int16``, ``int32``, ``int64``, ``int8``, ``uint8``, ``complex64``, ``complex128``, ``bfloat16`` or ``bool``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
One Tensor, if there is only one input.
List of Tensors, if there are more than one inputs.
Examples:
.. code-block:: python
>>> import paddle
>>> # one input
>>> x = paddle.to_tensor(123, dtype='int32')
>>> out = paddle.atleast_3d(x)
>>> print(out)
Tensor(shape=[1, 1, 1], dtype=int32, place=Place(cpu), stop_gradient=True,
[[[123]]])
>>> # more than one inputs
>>> x = paddle.to_tensor(123, dtype='int32')
>>> y = paddle.to_tensor([1.23], dtype='float32')
>>> out = paddle.atleast_3d(x, y)
>>> print(out)
[Tensor(shape=[1, 1, 1], dtype=int32, place=Place(cpu), stop_gradient=True,
[[[123]]]), Tensor(shape=[1, 1, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[1.23000002]]])]
>>> # more than 3-D input
>>> x = paddle.to_tensor(123, dtype='int32')
>>> y = paddle.to_tensor([[[[1.23]]]], dtype='float32')
>>> out = paddle.atleast_3d(x, y)
>>> print(out)
[Tensor(shape=[1, 1, 1], dtype=int32, place=Place(cpu), stop_gradient=True,
[[[123]]]), Tensor(shape=[1, 1, 1, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[[1.23000002]]]])]
"""
out = []
for tensor in inputs:
tensor = paddle.to_tensor(tensor)
if tensor.dim() == 0:
result = tensor.reshape((1, 1, 1))
elif tensor.dim() == 1:
result = paddle.unsqueeze(tensor, axis=[0, 2])
elif tensor.dim() == 2:
result = paddle.unsqueeze(tensor, axis=2)
else:
result = tensor
out.append(result)

if len(out) == 1:
return out[0]
else:
return out


def gather_nd(x, index, name=None):
"""
458 changes: 458 additions & 0 deletions test/legacy_test/test_atleast_nd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,458 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import parameterized as param

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

RTOL = 1e-5
ATOL = 1e-8

PLACES = [paddle.CPUPlace()] + (
[paddle.CUDAPlace(0)] if core.is_compiled_with_cuda() else []
)


def func_ref(func, *inputs):
"""ref func, just for convenience"""
return func(*inputs)


test_list = [
(paddle.atleast_1d, np.atleast_1d),
(paddle.atleast_2d, np.atleast_2d),
(paddle.atleast_3d, np.atleast_3d),
]


def generate_data(ndim, count=1, max_size=4, mix=False, dtype='int32'):
"""generate test data
Args:
ndim(int): dim of inputs
count(int): input count for each dim
max_size(int): max size for each dim
mix(bool): mix data types or not, like a data list [123, np.array(123), paddle.to_tensor(123), ...]
dtype(str): dtype
Returns:
a list of data like:
[[data, dtype, shape, name], [data, dtype, shape, name] ... ]
"""

rtn = []
for d in range(ndim):
data = [
np.random.randint(
0,
255,
size=[np.random.randint(1, max_size) for _ in range(d)],
dtype=dtype,
)
for _ in range(count)
]

if mix:

def _mix_data(data, idx):
if idx % 3 == 0:
return data.tolist()
elif idx % 3 == 1:
return data
elif idx % 3 == 2:
return paddle.to_tensor(data)

# mix normal/numpy/tensor
rtn.append(
list(
zip(
*[
[
_mix_data(_data, idx),
str(_data.dtype),
_data.shape,
'{}d_{}_{}'.format(d, idx, 'mix'),
]
for idx, _data in enumerate(data)
]
)
)
)

else:
# normal
rtn.append(
list(
zip(
*[
[
_data.tolist(),
str(_data.dtype),
_data.shape,
'{}d_{}_{}'.format(d, idx, 'normal'),
]
for idx, _data in enumerate(data)
]
)
)
)
# numpy
rtn.append(
list(
zip(
*[
[
_data,
str(_data.dtype),
_data.shape,
'{}d_{}_{}'.format(d, idx, 'numpy'),
]
for idx, _data in enumerate(data)
]
)
)
)
# tensor
rtn.append(
list(
zip(
*[
[
paddle.to_tensor(_data),
str(_data.dtype),
_data.shape,
'{}d_{}_{}'.format(d, idx, 'tensor'),
]
for idx, _data in enumerate(data)
]
)
)
)
return rtn


class BaseTest(unittest.TestCase):
"""Test in each `PLACES`, each `test_list`, and in `static/dygraph`"""

@test_with_pir_api
def _test_static_api(
self,
inputs: list,
dtypes: list,
shapes: list,
names: list,
):
"""Test `static`, convert `Tensor` to `numpy array` before feed into graph"""
for place in PLACES:
paddle.enable_static()
for func, func_type in test_list:
with paddle.static.program_guard(paddle.static.Program()):
x = []
feed = {}
for i in range(len(inputs)):
input = inputs[i]
shape = shapes[i]
dtype = dtypes[i]
name = names[i]
x.append(paddle.static.data(name, shape, dtype))
# the data feeded should NOT be a Tensor
feed[name] = (
input.numpy()
if isinstance(input, paddle.Tensor)
else input
)

out = func(*x)
exe = paddle.static.Executor(place)
res = exe.run(feed=feed, fetch_list=[out])

# unwrap inputs when lenght 1
if len(inputs) == 1:
res = res[0]

out_ref = func_ref(
func_type,
*[
input.numpy()
if isinstance(input, paddle.Tensor)
else input
for input in inputs
]
)

for n, p in zip(out_ref, res):
np.testing.assert_allclose(n, p, rtol=RTOL, atol=ATOL)

def _test_dygraph_api(
self,
inputs: list,
dtypes: list,
shapes: list,
names: list,
):
"""Test `dygraph`, and check grads"""
for place in PLACES:
paddle.disable_static(place)
for func, func_type in test_list:
out = func(*inputs)
out_ref = func_ref(
func_type,
*[
input.numpy()
if isinstance(input, paddle.Tensor)
else input
for input in inputs
]
)

for n, p in zip(out_ref, out):
np.testing.assert_allclose(
n, p.numpy(), rtol=RTOL, atol=ATOL
)

# check grads
if len(inputs) == 1:
out = [out]

for y in out:
y.stop_gradient = False
z = y * 123
grads = paddle.grad(z, y)
self.assertTrue(len(grads), 1)
self.assertEqual(grads[0].dtype, y.dtype)
self.assertEqual(grads[0].shape, y.shape)


@param.parameterized_class(
('inputs', 'dtypes', 'shapes', 'names'),
(generate_data(5, count=1, max_size=4, dtype='int32')),
)
class TestAtleastDim(BaseTest):
"""test dim from 0 to 5"""

def test_all(self):
self._test_dygraph_api(
self.inputs, self.dtypes, self.shapes, self.names
)
self._test_static_api(self.inputs, self.dtypes, self.shapes, self.names)


@param.parameterized_class(
('inputs', 'dtypes', 'shapes', 'names'),
(generate_data(5, count=3, max_size=4, dtype='int32')),
)
class TestAtleastDimMoreInputs(BaseTest):
"""test inputs of 3 tensors"""

def test_all(self):
self._test_dygraph_api(
self.inputs, self.dtypes, self.shapes, self.names
)
self._test_static_api(self.inputs, self.dtypes, self.shapes, self.names)


@param.parameterized_class(
('inputs', 'dtypes', 'shapes', 'names'),
(generate_data(5, count=5, max_size=4, mix=True, dtype='int32')),
)
class TestAtleastMixData(BaseTest):
"""test mix number/numpy/tensor"""

def test_all(self):
self._test_dygraph_api(
self.inputs, self.dtypes, self.shapes, self.names
)
self._test_static_api(self.inputs, self.dtypes, self.shapes, self.names)


@param.parameterized_class(
('inputs', 'dtypes', 'shapes', 'names'),
(
(
(
123,
np.array([123], dtype='int32'),
paddle.to_tensor([[123]], dtype='int32'),
[[[123]]],
np.array([[[[123]]]], dtype='int32'),
paddle.to_tensor([[[[[123]]]]], dtype='int32'),
),
('int32', 'int32', 'int32', 'int32', 'int32', 'int32'),
((), (1,), (1, 1), (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1)),
(
'0_mixdim',
'1_mixdim',
'2_mixdim',
'3_mixdim',
'4_mixdim',
'5_mixdim',
),
),
),
)
class TestAtleastMixDim(BaseTest):
"""test mix dim"""

def test_all(self):
self._test_dygraph_api(
self.inputs, self.dtypes, self.shapes, self.names
)
self._test_static_api(self.inputs, self.dtypes, self.shapes, self.names)


@param.parameterized_class(
('inputs', 'dtypes', 'shapes', 'names'),
(
(
(
paddle.to_tensor(True, dtype='bool'),
paddle.to_tensor(0.1, dtype='float16'),
paddle.to_tensor(0.1, dtype='float32'),
paddle.to_tensor(0.1, dtype='float64'),
paddle.to_tensor(1, dtype='int8'),
paddle.to_tensor(1, dtype='int16'),
paddle.to_tensor(1, dtype='int32'),
paddle.to_tensor(1, dtype='int64'),
paddle.to_tensor(1, dtype='uint8'),
paddle.to_tensor(1 + 1j, dtype='complex64'),
paddle.to_tensor(1 + 1j, dtype='complex128'),
paddle.to_tensor(0.1, dtype='bfloat16'),
),
(
'bool',
'float16',
'float32',
'float64',
'int8',
'int16',
'int32',
'int64',
'uint8',
'complex64',
'complex128',
'bfloat16',
),
(
(),
(),
(),
(),
(),
(),
(),
(),
(),
(),
(),
(),
),
(
'0_mixdtype',
'1_mixdtype',
'2_mixdtype',
'3_mixdtype',
'4_mixdtype',
'5_mixdtype',
'6_mixdtype',
'7_mixdtype',
'8_mixdtype',
'9_mixdtype',
'10_mixdtype',
'11_mixdtype',
),
),
),
)
class TestAtleastMixDtypes(BaseTest):
"""test mix dtypes"""

def test_all(self):
self._test_dygraph_api(
self.inputs, self.dtypes, self.shapes, self.names
)
self._test_static_api(self.inputs, self.dtypes, self.shapes, self.names)


@param.parameterized_class(
('inputs', 'dtypes', 'shapes', 'names'),
(
(((123, [123]),), ('int32',), ((),), ('0_combine',)),
(
((np.array([123], dtype='int32'), [[123]]),),
('int32',),
((),),
('1_combine',),
),
(
(
(
np.array([[123]], dtype='int32'),
paddle.to_tensor([[[123]]], dtype='int32'),
),
),
('int32',),
((),),
('2_combine',),
),
),
)
class TestAtleastErrorCombineInputs(BaseTest):
"""test combine inputs, like: `at_leastNd((x, y))`, where paddle treats like numpy"""

def test_all(self):
with self.assertRaises(ValueError):
self._test_dygraph_api(
self.inputs, self.dtypes, self.shapes, self.names
)

with self.assertRaises(ValueError):
self._test_static_api(
self.inputs, self.dtypes, self.shapes, self.names
)


class TestAtleastAsTensorMethod(unittest.TestCase):
def test_as_tensor_method(self):
input = 123
tensor = paddle.to_tensor(input)

for place in PLACES:
paddle.disable_static(place)

out = tensor.atleast_1d()
out_ref = np.atleast_1d(input)

for n, p in zip(out_ref, out):
np.testing.assert_allclose(n, p.numpy(), rtol=RTOL, atol=ATOL)

out = tensor.atleast_2d()
out_ref = np.atleast_2d(input)

for n, p in zip(out_ref, out):
np.testing.assert_allclose(n, p.numpy(), rtol=RTOL, atol=ATOL)

out = tensor.atleast_3d()
out_ref = np.atleast_3d(input)

for n, p in zip(out_ref, out):
np.testing.assert_allclose(n, p.numpy(), rtol=RTOL, atol=ATOL)


if __name__ == '__main__':
unittest.main()

0 comments on commit b6dbb2b

Please sign in to comment.