def _roll_along_split(x: DNDarray, shift: int) -> DNDarray:
    """
    Rolls ``x`` by ``shift`` positions along its split axis by exchanging slices between processes.

    Every index along the split axis is transferred as its own non-blocking message, tagged with the
    global index at which the slice ends up.

    Parameters
    ----------
    x : DNDarray
        input array, must be distributed (``x.split`` is not ``None``)
    shift : int
        number of places by which the elements are shifted along ``x.split``
    """
    split = x.split
    size = x.comm.Get_size()
    rank = x.comm.Get_rank()
    n_global = x.gshape[split]
    n_local = x.lshape[split]

    # number of local elements along the split axis, one entry per process
    lshape_map = x.create_lshape_map(force_check=False)[:, split]
    cumsum_map = torch.cumsum(lshape_map, dim=0)  # cumulate along axis
    indices = torch.arange(size, device=x.device.torch_device)
    # NOTE Can be removed when min version>=1.9
    if "1.7." in torch.__version__ or "1.8." in torch.__version__:
        lshape_map = lshape_map.to(torch.int64)
    # index_map[g] is the process holding global index g along the split axis
    index_map = torch.repeat_interleave(indices, lshape_map)

    # global indices of the elements stored on this process
    index_old = torch.arange(lshape_map[rank], device=x.device.torch_device)
    if rank > 0:
        index_old += cumsum_map[rank - 1]

    # where the local elements go to, and where the elements landing here come from
    send_index = (index_old + shift) % n_global
    recv_index = (index_old - shift) % n_global

    # post all receives first; the tag is the global index the incoming slice will occupy
    recv = torch.empty_like(x.larray)
    recv_splits = torch.split(recv, 1, dim=split)
    recv_requests = [
        x.comm.Irecv(recv_splits[i], index_map[recv_index[i]], index_old[i])
        for i in range(n_local)
    ]

    # the send tag is the destination's global index, which matches the tag of its receive
    send_splits = torch.split(x.larray, 1, dim=split)
    send_requests = [
        x.comm.Isend(send_splits[i], index_map[send_index[i]], send_index[i])
        for i in range(n_local)
    ]

    for request in recv_requests:
        request.Wait()
    for request in send_requests:
        request.Wait()

    return DNDarray(recv, x.gshape, x.dtype, x.split, x.device, x.comm, x.balanced)


def roll(
    x: DNDarray, shift: Union[int, Tuple[int]], axis: Optional[Union[int, Tuple[int]]] = None
) -> DNDarray:
    """
    Rolls array elements along a specified axis. Array elements that roll beyond the last position are re-introduced at the first position.
    Array elements that roll beyond the first position are re-introduced at the last position.

    The result is a new array with the global shape, data type, split axis, device and communicator of 'x'. The sequences passed as
    'shift' and 'axis' are not modified.

    Parameters
    ----------
    x : DNDarray
        input array
    shift : Union[int, Tuple[int, ...]]
        number of places by which the elements are shifted. If 'shift' is a tuple, then 'axis' must be a tuple of the same size, and each of
        the given axes is shifted by the corresponding element in 'shift'. If 'shift' is an `int` and 'axis' a `tuple`, then the same shift
        is used for all specified axes.
    axis : Optional[Union[int, Tuple[int, ...]]]
        axis (or axes) along which elements to shift. If 'axis' is `None`, the array is flattened, shifted, and then restored to its original shape.
        Default: `None`.

    Raises
    ------
    TypeError
        If 'shift' or 'axis' is not of type `int`, `list` or `tuple`, or if a sequence contains an element that is not an `int`.
    ValueError
        If 'shift' and 'axis' are tuples with different sizes.

    Examples
    --------
    >>> a = ht.arange(20).reshape((4,5))
    >>> a
    DNDarray([[ 0,  1,  2,  3,  4],
              [ 5,  6,  7,  8,  9],
              [10, 11, 12, 13, 14],
              [15, 16, 17, 18, 19]], dtype=ht.int32, device=cpu:0, split=None)
    >>> ht.roll(a, 1)
    DNDarray([[19,  0,  1,  2,  3],
              [ 4,  5,  6,  7,  8],
              [ 9, 10, 11, 12, 13],
              [14, 15, 16, 17, 18]], dtype=ht.int32, device=cpu:0, split=None)
    >>> ht.roll(a, -1, 0)
    DNDarray([[ 5,  6,  7,  8,  9],
              [10, 11, 12, 13, 14],
              [15, 16, 17, 18, 19],
              [ 0,  1,  2,  3,  4]], dtype=ht.int32, device=cpu:0, split=None)
    """
    sanitation.sanitize_in(x)

    if axis is None:
        # roll the flattened array, then restore the original shape and split axis
        return roll(x.flatten(), shift, 0).reshape(x.shape, new_split=x.split)

    if isinstance(shift, int):
        if isinstance(axis, int):
            # a negative 'axis' addresses the split axis as well once offset by ndim
            if x.split is not None and (axis == x.split or (axis + x.ndim) == x.split):
                return _roll_along_split(x, shift)
            # any other single axis is purely local and handled by torch.roll below
        else:
            # pytorch does not support int / sequence combo at the time, make shift a list instead
            try:
                axis = sanitation.sanitize_sequence(axis)
            except TypeError:
                raise TypeError("axis must be a int, list or a tuple, got {}".format(type(axis)))

            return roll(x, [shift] * len(axis), axis)

    else:  # input must be tuples now
        try:
            shift = sanitation.sanitize_sequence(shift)
        except TypeError:
            raise TypeError("shift must be an integer, list or a tuple, got {}".format(type(shift)))

        try:
            axis = sanitation.sanitize_sequence(axis)
        except TypeError:
            raise TypeError("axis must be an integer, list or a tuple, got {}".format(type(axis)))

        if len(shift) != len(axis):
            raise ValueError(
                "shift and axis length must be the same, got {} and {}".format(
                    len(shift), len(axis)
                )
            )

        for i, (shift_i, axis_i) in enumerate(zip(shift, axis)):
            if not isinstance(shift_i, int):
                raise TypeError(
                    "Element {} in shift is not an integer, got {}".format(i, type(shift_i))
                )
            if not isinstance(axis_i, int):
                raise TypeError(
                    "Element {} in axis is not an integer, got {}".format(i, type(axis_i))
                )

        if x.split is not None:
            # the split axis may be addressed by its positive or its negative index
            split_axes = (x.split, x.split - x.ndim)
            if any(axis_i in split_axes for axis_i in axis):
                # Accumulate every shift aimed at the split axis and keep the remaining pairs.
                # New lists are built on purpose: sanitize_sequence may hand back the caller's
                # own list objects (not verifiable from here), which must not be altered.
                shift_split = sum(s for s, ax in zip(shift, axis) if ax in split_axes)
                local_pairs = [(s, ax) for s, ax in zip(shift, axis) if ax not in split_axes]
                shift = [s for s, _ in local_pairs]
                axis = [ax for _, ax in local_pairs]

                # compute new array along split axis
                x = roll(x, shift_split, x.split)
                if not axis:
                    return x

    # use PyTorch for all other axes
    rolled = torch.roll(x.larray, shift, axis)
    return DNDarray(
        rolled,
        gshape=x.shape,
        dtype=x.dtype,
        split=x.split,
        device=x.device,
        comm=x.comm,
        balanced=x.balanced,
    )
def test_roll(self):
    """
    Checks ``ht.roll`` on non-distributed and distributed arrays against torch / NumPy results
    and verifies that invalid 'shift' / 'axis' arguments raise.
    """

    def check_meta(result, source):
        # rolling must not change device, size, dtype or split of the array
        self.assertEqual(result.device, source.device)
        self.assertEqual(result.size, source.size)
        self.assertEqual(result.dtype, source.dtype)
        self.assertEqual(result.split, source.split)

    def fresh(args):
        # new list objects per call, so ht.roll and np.roll never share an argument list
        return tuple(list(arg) if isinstance(arg, list) else arg for arg in args)

    def check_against_numpy(source, cases):
        for args in cases:
            rolled = ht.roll(source, *fresh(args))
            compare = np.roll(source.numpy(), *fresh(args))
            check_meta(rolled, source)
            self.assertTrue(np.array_equal(rolled.numpy(), compare))

    # no split
    # vector
    a = ht.arange(5)
    for shift, expected in ((1, [4, 0, 1, 2, 3]), (-1, [1, 2, 3, 4, 0])):
        rolled = ht.roll(a, shift)
        compare = ht.array(expected)
        check_meta(rolled, a)
        self.assertTrue(ht.equal(rolled, compare))

    # matrix
    a = ht.arange(20.0).reshape((4, 5))

    for args in ((-1,), (1, 0)):
        rolled = ht.roll(a, *args)
        compare = torch.roll(a.larray, *args)
        check_meta(rolled, a)
        self.assertTrue(torch.equal(rolled.larray, compare))

    rolled = ht.roll(a, -2, (0, 1))
    compare = np.roll(a.larray.cpu().numpy(), -2, (0, 1))
    check_meta(rolled, a)
    self.assertTrue(np.array_equal(rolled.larray.cpu().numpy(), compare))

    rolled = ht.roll(a, (1, 2, 1), (0, 1, -2))
    compare = torch.roll(a.larray, (1, 2, 1), (0, 1, -2))
    check_meta(rolled, a)
    self.assertTrue(torch.equal(rolled.larray, compare))

    # split
    # vector
    a = ht.arange(5, dtype=ht.uint8, split=0)

    rolled = ht.roll(a, 1)
    compare = ht.array([4, 0, 1, 2, 3], dtype=ht.uint8, split=0)
    check_meta(rolled, a)
    self.assertTrue(ht.equal(rolled, compare))

    rolled = ht.roll(a, -1)
    compare = ht.array([1, 2, 3, 4, 0], ht.uint8, split=0)
    check_meta(rolled, a)
    self.assertTrue(ht.equal(rolled, compare))

    # matrix
    tuple_cases = ((-1,), (1, 0), (-2, (0, 1)), ((1, 2, 1), (0, 1, -2)))
    list_cases = ((-1,), (1, 0), (-2, [0, 1]), ([1, 2, 1], [0, 1, -2]))

    a = ht.arange(20).reshape((4, 5), dtype=ht.int16, new_split=0)
    check_against_numpy(a, tuple_cases)

    a = ht.arange(20, dtype=ht.complex64).reshape((4, 5), new_split=1)
    check_against_numpy(a, list_cases)

    # added 3D test, only a quick test for functionality
    a = ht.arange(4 * 5 * 6, dtype=ht.complex64).reshape((4, 5, 6), new_split=2)
    check_against_numpy(a, list_cases)

    # invalid arguments, checked on the 3D array from above
    invalid_calls = (
        (TypeError, 1.0, 0),
        (TypeError, 1, 1.0),
        (TypeError, 1, (1.0, 0.0)),
        (TypeError, (-1, 1), 0.0),
        (TypeError, (-1.0, 1.0), (0, 0)),
        (ValueError, [1, 1, 1], [0, 0]),
    )
    for expected_error, shift, axis in invalid_calls:
        with self.assertRaises(expected_error):
            ht.roll(a, shift, axis)