diff --git a/CHANGELOG.md b/CHANGELOG.md index ec4c7162d6..ed732683ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,7 @@ - [#653](https://github.com/helmholtz-analytics/heat/pull/653) Printing above threshold gathers the data without a buffer now - [#653](https://github.com/helmholtz-analytics/heat/pull/653) Bugfixes: Update unittests argmax & argmin + force index order in mpi_argmax & mpi_argmin. Add device parameter for tensor creation in dndarray.get_halo(). - [#664](https://github.com/helmholtz-analytics/heat/pull/664) New feature / enhancement: `random.random_sample`, `random.random`, `random.sample`, `random.ranf`, `random.random_integer` +- [#667](https://github.com/helmholtz-analytics/heat/pull/667) Enhancement `reshape`: rename axis parameter # v0.4.0 diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index db8fb7a321..32f382669b 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -911,7 +911,7 @@ def hstack(tup): return concatenate(tup, axis=axis) -def reshape(a, shape, axis=None): +def reshape(a, shape, new_split=None): """ Returns a tensor with the same data and number of elements as a, but with the specified shape. @@ -921,8 +921,8 @@ def reshape(a, shape, axis=None): The input tensor shape : tuple, list Shape of the new tensor - axis : int, optional - The new split axis. None denotes same axis + new_split : int, optional + The new split axis if `a` is a split DNDarray. None denotes same axis. Default : None Returns @@ -953,10 +953,10 @@ def reshape(a, shape, axis=None): raise TypeError("'a' must be a DNDarray, currently {}".format(type(a))) if not isinstance(shape, (list, tuple)): raise TypeError("shape must be list, tuple, currently {}".format(type(shape))) - # check axis parameter - if axis is None: - axis = a.split - stride_tricks.sanitize_axis(shape, axis) + # check new_split parameter + if new_split is None: + new_split = a.split + stride_tricks.sanitize_axis(shape, new_split) tdtype, tdevice = a.dtype.torch_type(), a.device.torch_device # Check the type of shape and number elements shape = stride_tricks.sanitize_shape(shape) @@ -1005,21 +1005,21 @@ def reshape_argsort_counts_displs( ) # Create new flat result tensor - _, local_shape, _ = a.comm.chunk(shape, axis) + _, local_shape, _ = a.comm.chunk(shape, new_split) data = torch.empty(local_shape, dtype=tdtype, device=tdevice).flatten() # Calculate the counts and displacements _, old_displs, _ = a.comm.counts_displs_shape(a.shape, a.split) - _, new_displs, _ = a.comm.counts_displs_shape(shape, axis) + _, new_displs, _ = a.comm.counts_displs_shape(shape, new_split) old_displs += (a.shape[a.split],) - new_displs += (shape[axis],) + new_displs += (shape[new_split],) sendsort, sendcounts, senddispls = reshape_argsort_counts_displs( - a.shape, a.lshape, old_displs, a.split, shape, new_displs, axis, a.comm + a.shape, a.lshape, old_displs, a.split, shape, new_displs, new_split, a.comm ) recvsort, recvcounts, recvdispls = reshape_argsort_counts_displs( - shape, local_shape, new_displs, axis, a.shape, old_displs, a.split, a.comm + shape, local_shape, new_displs, new_split, a.shape, old_displs, a.split, a.comm ) # rearange order @@ -1033,7 +1033,7 @@ def reshape_argsort_counts_displs( # Reshape local tensor data = data.reshape(local_shape) - return factories.array(data, dtype=a.dtype, is_split=axis, device=a.device, comm=a.comm) + return factories.array(data, dtype=a.dtype, is_split=new_split, device=a.device, comm=a.comm) def rot90(m, k=1, axes=(0, 1)): diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 52efc00259..5243a583be 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1082,7 +1082,7 @@ def test_reshape(self): a = ht.array(torch.arange(3 * 4 * 5).reshape([3, 4, 5]), split=2) result = ht.array(torch.arange(4 * 5 * 3).reshape([4, 5, 3]), split=1) - reshaped = ht.reshape(a, [4, 5, 3], axis=1) + reshaped = ht.reshape(a, [4, 5, 3], new_split=1) self.assertEqual(reshaped.size, result.size) self.assertEqual(reshaped.shape, result.shape) self.assertEqual(reshaped.split, 1) @@ -1090,7 +1090,7 @@ def test_reshape(self): a = ht.array(torch.arange(3 * 4 * 5).reshape([3, 4, 5]), split=1) result = ht.array(torch.arange(4 * 5 * 3).reshape([4 * 5, 3]), split=0) - reshaped = ht.reshape(a, [4 * 5, 3], axis=0) + reshaped = ht.reshape(a, [4 * 5, 3], new_split=0) self.assertEqual(reshaped.size, result.size) self.assertEqual(reshaped.shape, result.shape) self.assertEqual(reshaped.split, 0) @@ -1098,7 +1098,7 @@ def test_reshape(self): a = ht.array(torch.arange(3 * 4 * 5).reshape([3, 4, 5]), split=0) result = ht.array(torch.arange(4 * 5 * 3).reshape([4, 5 * 3]), split=1) - reshaped = ht.reshape(a, [4, 5 * 3], axis=1) + reshaped = ht.reshape(a, [4, 5 * 3], new_split=1) self.assertEqual(reshaped.size, result.size) self.assertEqual(reshaped.shape, result.shape) self.assertEqual(reshaped.split, 1)