Skip to content

Commit

Permalink
[Fix IndexError] add unstack axis check (PaddlePaddle#49943)
Browse files Browse the repository at this point in the history
* add unstack axis check

* IndexErr -> ValueError

* add static select
  • Loading branch information
DrRyanHuang authored and pangengzheng committed Feb 2, 2023
1 parent 7ccc7dc commit 12c0b48
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
24 changes: 24 additions & 0 deletions python/paddle/fluid/tests/unittests/test_unstack_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,29 @@ def initParameters(self):
self.axis = 2


class TestUnstackZeroInputOp(unittest.TestCase):
def unstack_zero_input_static(self):

paddle.enable_static()

array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0]), dtype='float32')
paddle.unstack(x, axis=1)

def unstack_zero_input_dynamic(self):

array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0]), dtype='float32')
paddle.unstack(x, axis=1)

def test_type_error(self):
paddle.disable_static()

self.assertRaises(ValueError, self.unstack_zero_input_dynamic)
self.assertRaises(ValueError, self.unstack_zero_input_static)

paddle.disable_static()


if __name__ == '__main__':
unittest.main()
4 changes: 4 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,10 @@ def unstack(x, axis=0, num=None):
y = paddle.unstack(x, axis=1) # unstack with second axis, which results 3 tensors with shape=[2, 5]
"""
if not (-x.ndim <= axis < x.ndim):
raise ValueError(
'`axis` must be in the range [-{0}, {0})'.format(x.ndim)
)
if in_dygraph_mode():
if num is None:
num = x.shape[axis]
Expand Down

0 comments on commit 12c0b48

Please sign in to comment.