Skip to content

Commit

Permalink
【Hackathon 5th No.32】为 Paddle 新增 tensor_split / hsplit / dsplit API (#…
Browse files Browse the repository at this point in the history
…58917)

* [Init] add more split api

* [Update] update unittest

* [Add] add docstrings

* [Fix] fix merge

* [Change] tensor_split with split

* [Fix] remove out of range example

* [Fix] tensor_split docstring of supported data type

* [Change] _tensor_split_indices with slice

* [Change] resolve conflict

* [Change] h v d -split like tensor_split
  • Loading branch information
megemini authored Dec 13, 2023
1 parent 46e3dfe commit 538905c
Show file tree
Hide file tree
Showing 4 changed files with 895 additions and 146 deletions.
6 changes: 6 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,15 @@
concat,
crop,
diagonal_scatter,
dsplit,
expand,
expand_as,
flatten,
flip,
flip as reverse,
gather,
gather_nd,
hsplit,
index_add,
index_add_,
index_fill,
Expand Down Expand Up @@ -309,6 +311,7 @@
row_stack,
strided_slice,
take_along_axis,
tensor_split,
tensordot,
tile,
tolist,
Expand Down Expand Up @@ -631,6 +634,9 @@
'searchsorted',
'bucketize',
'split',
'tensor_split',
'hsplit',
'dsplit',
'vsplit',
'logical_and',
'logical_and_',
Expand Down
6 changes: 6 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@
column_stack,
concat,
diagonal_scatter,
dsplit,
dstack,
expand,
expand_as,
Expand All @@ -158,6 +159,7 @@
flip as reverse,
gather,
gather_nd,
hsplit,
hstack,
index_add,
index_add_,
Expand Down Expand Up @@ -189,6 +191,7 @@
stack,
strided_slice,
take_along_axis,
tensor_split,
tensordot,
tile,
unbind,
Expand Down Expand Up @@ -608,6 +611,9 @@
'shard_index',
'slice',
'split',
'tensor_split',
'hsplit',
'dsplit',
'vsplit',
'chunk',
'tensordot',
Expand Down
241 changes: 223 additions & 18 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2571,17 +2571,227 @@ def _get_SectionsTensorList(one_list):
return outs


def vsplit(x, num_or_sections, name=None):
def tensor_split(x, num_or_indices, axis=0, name=None):
"""
Split the input tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.split`` with ``axis=0``.
Split the input tensor into multiple sub-Tensors along ``axis``, allowing not being of equal size.
Args:
x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, float16, float32, float64, uint8, int8, int32 or int64.
num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections``
indicates the number of equal sized sub-Tensors that the ``x`` will be divided into.
If ``num_or_sections`` is a list or tuple, the length of it indicates the number of
sub-Tensors and the elements in it indicate the sizes of sub-Tensors' dimension orderly.
The length of the list must not be larger than the ``x`` 's size of axis 0.
x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections along ``axis``.
If ``x`` is divisible by ``n``, each section will be ``x.shape[axis] / n``. If ``x`` is not divisible by ``n``, the first
``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n).
If ``num_or_indices`` is a list or tuple of integter indices, ``x`` is split along ``axis`` at each of the indices. For instance,
``num_or_indices=[2, 4]`` with ``axis=0`` would split ``x`` into ``x[:2]``, ``x[2:4]`` and ``x[4:]`` along axis 0.
axis (int|Tensor, optional): The axis along which to split, it can be a integer or a ``0-D Tensor``
with shape [] and data type ``int32`` or ``int64``.
If :math::`axis < 0`, the axis to split along is :math:`rank(x) + axis`. Default is 0.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
list[Tensor], The list of segmented Tensors.
Examples:
.. code-block:: python
>>> import paddle
>>> # x is a Tensor of shape [8]
>>> # evenly split
>>> x = paddle.rand([8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2)
>>> print(out0.shape)
[4]
>>> print(out1.shape)
[4]
>>> # not evenly split
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=3)
>>> print(out0.shape)
[3]
>>> print(out1.shape)
[3]
>>> print(out2.shape)
[2]
>>> # split with indices
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3])
>>> print(out0.shape)
[2]
>>> print(out1.shape)
[1]
>>> print(out2.shape)
[5]
>>> # split along axis
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2, axis=1)
>>> print(out0.shape)
[7, 4]
>>> print(out1.shape)
[7, 4]
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3], axis=1)
>>> print(out0.shape)
[7, 2]
>>> print(out1.shape)
[7, 1]
>>> print(out2.shape)
[7, 5]
"""
if x.ndim <= 0 or x.ndim <= axis:
raise ValueError(
f"The input tensor's dimension must be greater than 0 or axis which is {axis}, but got {x.ndim}"
)

total_n = x.shape[axis]

def _tensor_split_indices(x, total_n, indices, axis):
splits = []

starts = 0
ends = 0
for idx in list(indices) + [total_n]:
ends = idx
# convert index < 0 to positive
starts_index = starts if starts >= 0 else total_n + starts
ends_index = ends if ends >= 0 else total_n + ends
# ends index should equal or larger than starts
ends_index = max(starts_index, ends_index)

sub_array = paddle.slice(
x, axes=[axis], starts=[starts_index], ends=[ends_index]
)
splits.append(sub_array)
starts = ends

return splits

def _tensor_split_sections(x, total_n, sections, axis):
if sections <= 0:
raise ValueError('num_or_indices must be larger than 0.')

base, mod = divmod(total_n, sections)
num_or_sections = [base + 1] * mod + [base] * (sections - mod)
return split(x, num_or_sections, axis)

if isinstance(num_or_indices, int):
return _tensor_split_sections(x, total_n, num_or_indices, axis)

elif isinstance(num_or_indices, (list, tuple)):
return _tensor_split_indices(x, total_n, num_or_indices, axis)

else:
raise ValueError(
f"The num_or_indices should be int, list or tuple of ints, but got {type(num_or_indices)}"
)


def hsplit(x, num_or_indices, name=None):
"""
Split the input tensor into multiple sub-Tensors along the horizontal axis, which is equivalent to ``paddle.tensor_split`` with ``axis=1``
when ``x`` 's dimension is larger than 1, or equivalent to ``paddle.tensor_split`` with ``axis=0`` when ``x`` 's dimension is 1.
Args:
x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections.
If ``num_or_indices`` is a list or tuple of integter indices, ``x`` is split at each of the indices.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
list[Tensor], The list of segmented Tensors.
Examples:
.. code-block:: python
>>> import paddle
>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1 = paddle.hsplit(x, num_or_indices=2)
>>> print(out0.shape)
[4]
>>> print(out1.shape)
[4]
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1 = paddle.hsplit(x, num_or_indices=2)
>>> print(out0.shape)
[7, 4]
>>> print(out1.shape)
[7, 4]
>>> out0, out1, out2 = paddle.hsplit(x, num_or_indices=[1, 4])
>>> print(out0.shape)
[7, 1]
>>> print(out1.shape)
[7, 3]
>>> print(out2.shape)
[7, 4]
"""
if x.ndim < 1:
raise ValueError(
f"The input tensor's dimension must be greater than 0, but got {x.ndim}"
)
if x.ndim > 1:
return tensor_split(x, num_or_indices, axis=1, name=name)
else:
return tensor_split(x, num_or_indices, axis=0, name=name)


def dsplit(x, num_or_indices, name=None):
"""
Split the input tensor into multiple sub-Tensors along the depth axis, which is equivalent to ``paddle.tensor_split`` with ``axis=2``.
Args:
x (Tensor): A Tensor whose dimension must be greater than 2. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections.
If ``num_or_indices`` is a list or tuple of integter indices, ``x`` is split at each of the indices.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
list[Tensor], The list of segmented Tensors.
Examples:
.. code-block:: python
>>> import paddle
>>> # x is a Tensor of shape [7, 6, 8]
>>> x = paddle.rand([7, 6, 8])
>>> out0, out1 = paddle.dsplit(x, num_or_indices=2)
>>> print(out0.shape)
[7, 6, 4]
>>> print(out1.shape)
[7, 6, 4]
>>> out0, out1, out2 = paddle.dsplit(x, num_or_indices=[1, 4])
>>> print(out0.shape)
[7, 6, 1]
>>> print(out1.shape)
[7, 6, 3]
>>> print(out2.shape)
[7, 6, 4]
"""
if x.ndim < 3:
raise ValueError(
f"The input tensor's dimension must be greater than 2, but got {x.ndim}"
)
return tensor_split(x, num_or_indices, axis=2, name=name)


def vsplit(x, num_or_indices, name=None):
"""
Split the input tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.tensor_split`` with ``axis=0``.
Args:
x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections.
If ``num_or_indices`` is a list or tuple of integter indices, ``x`` is split at each of the indices.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
Expand All @@ -2594,31 +2804,26 @@ def vsplit(x, num_or_sections, name=None):
>>> # x is a Tensor of shape [8, 6, 7]
>>> x = paddle.rand([8, 6, 7])
>>> out0, out1 = paddle.vsplit(x, num_or_sections=2)
>>> out0, out1 = paddle.vsplit(x, num_or_indices=2)
>>> print(out0.shape)
[4, 6, 7]
>>> print(out1.shape)
[4, 6, 7]
>>> out0, out1, out2 = paddle.vsplit(x, num_or_sections=[1, 3, 4])
>>> out0, out1, out2 = paddle.vsplit(x, num_or_indices=[1, 4])
>>> print(out0.shape)
[1, 6, 7]
>>> print(out1.shape)
[3, 6, 7]
>>> print(out2.shape)
[4, 6, 7]
>>> out0, out1, out2 = paddle.vsplit(x, num_or_sections=[2, 3, -1])
>>> print(out0.shape)
[2, 6, 7]
>>> print(out1.shape)
[3, 6, 7]
>>> print(out2.shape)
[3, 6, 7]
"""
if x.ndim < 2:
raise ValueError(
f"The input tensor's dimension must be greater than 1, but got {x.ndim}"
)
return split(x, num_or_sections, axis=0, name=name)
return tensor_split(x, num_or_indices, axis=0, name=name)


def squeeze(x, axis=None, name=None):
Expand Down
Loading

0 comments on commit 538905c

Please sign in to comment.