Skip to content

Commit

Permalink
Merge pull request #854 from helmholtz-analytics/feature/848-moveaxis
Browse files Browse the repository at this point in the history
add moveaxis
  • Loading branch information
coquelin77 authored Aug 20, 2021
2 parents 86413ff + 638a172 commit e7c13c6
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 1 deletion.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

### Linear Algebra
- [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
## Manipulations
### Manipulations
- [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll`
- [#854](https://github.com/helmholtz-analytics/heat/pull/854) New Feature: `moveaxis`


# v1.1.0

Expand Down
66 changes: 66 additions & 0 deletions heat/core/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"flipud",
"hsplit",
"hstack",
"moveaxis",
"pad",
"ravel",
"redistribute",
Expand Down Expand Up @@ -1055,6 +1056,71 @@ def hstack(arrays: Sequence[DNDarray, ...]) -> DNDarray:
return concatenate(arrays, axis=axis)


def moveaxis(
x: DNDarray, source: Union[int, Sequence[int]], destination: Union[int, Sequence[int]]
) -> DNDarray:
"""
Moves axes at the positions in `source` to new positions.
Parameters
----------
x : DNDarray
The input array.
source : int or Sequence[int, ...]
Original positions of the axes to move. These must be unique.
destination : int or Sequence[int, ...]
Destination positions for each of the original axes. These must also be unique.
See Also
--------
~heat.core.linalg.basics.transpose
Permute the dimensions of an array.
Raises
------
TypeError
If `source` or `destination` are not ints, lists or tuples.
ValueError
If `source` and `destination` do not have the same number of elements.
Examples
--------
>>> x = ht.zeros((3, 4, 5))
>>> ht.moveaxis(x, 0, -1).shape
(4, 5, 3)
>>> ht.moveaxis(x, -1, 0).shape
(5, 3, 4)
"""
if isinstance(source, int):
source = (source,)
if isinstance(source, list):
source = tuple(source)
try:
source = stride_tricks.sanitize_axis(x.shape, source)
except TypeError:
raise TypeError("'source' must be ints, lists or tuples.")

if isinstance(destination, int):
destination = (destination,)
if isinstance(destination, list):
destination = tuple(destination)
try:
destination = stride_tricks.sanitize_axis(x.shape, destination)
except TypeError:
raise TypeError("'destination' must be ints, lists or tuples.")

if len(source) != len(destination):
raise ValueError("'source' and 'destination' must have the same number of elements.")

order = [n for n in range(x.ndim) if n not in source]

for dest, src in sorted(zip(destination, source)):
order.insert(dest, src)

return linalg.transpose(x, order)


def pad(
array: DNDarray,
pad_width: Union[int, Sequence[Sequence[int, int], ...]],
Expand Down
16 changes: 16 additions & 0 deletions heat/core/tests/test_manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,22 @@ def test_hstack(self):
res = ht.hstack((a, b))
self.assertEqual(res.shape, (24,))

def test_moveaxis(self):
a = ht.zeros((3, 4, 5))

moved = ht.moveaxis(a, 0, -1)
self.assertEquals(moved.shape, (4, 5, 3))

moved = ht.moveaxis(a, [0, 1], [-1, -2])
self.assertEquals(moved.shape, (5, 4, 3))

with self.assertRaises(TypeError):
ht.moveaxis(a, source="r", destination=3)
with self.assertRaises(TypeError):
ht.moveaxis(a, source=2, destination=3.6)
with self.assertRaises(ValueError):
ht.moveaxis(a, source=[0, 1, 2], destination=[0, 1])

def test_pad(self):
# ======================================
# test padding of non-distributed tensor
Expand Down

0 comments on commit e7c13c6

Please sign in to comment.