Merge branch 'master' into features/553-random_aliases
coquelin77 authored Sep 7, 2020
2 parents e0a3ab6 + 6f83026 commit f4b62dd
Showing 3 changed files with 17 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -47,6 +47,7 @@
- [#653](https://github.com/helmholtz-analytics/heat/pull/653) Printing above threshold gathers the data without a buffer now
- [#653](https://github.com/helmholtz-analytics/heat/pull/653) Bugfixes: Update unittests argmax & argmin + force index order in mpi_argmax & mpi_argmin. Add device parameter for tensor creation in dndarray.get_halo().
- [#664](https://github.com/helmholtz-analytics/heat/pull/664) New feature / enhancement: `random.random_sample`, `random.random`, `random.sample`, `random.ranf`, `random.random_integer`
- [#667](https://github.com/helmholtz-analytics/heat/pull/667) Enhancement `reshape`: rename axis parameter

# v0.4.0

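For reference, a minimal sketch of the renamed keyword in use, mirroring the updated tests further down (assuming the usual `import heat as ht` convention):

    import torch
    import heat as ht

    # Distributed 3-D array, split along the last axis.
    a = ht.array(torch.arange(3 * 4 * 5).reshape([3, 4, 5]), split=2)

    # Reshape and redistribute along a new split axis in one call.
    reshaped = ht.reshape(a, [4, 5, 3], new_split=1)
    assert reshaped.split == 1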
26 changes: 13 additions & 13 deletions heat/core/manipulations.py
@@ -911,7 +911,7 @@ def hstack(tup):
return concatenate(tup, axis=axis)


-def reshape(a, shape, axis=None):
+def reshape(a, shape, new_split=None):
"""
Returns a tensor with the same data and number of elements as a, but with the specified shape.
@@ -921,8 +921,8 @@ def reshape(a, shape, axis=None):
The input tensor
shape : tuple, list
Shape of the new tensor
-axis : int, optional
-    The new split axis. None denotes same axis
+new_split : int, optional
+    The new split axis if `a` is a split DNDarray. None denotes same axis.
Default : None
Returns
@@ -953,10 +953,10 @@ def reshape(a, shape, axis=None):
raise TypeError("'a' must be a DNDarray, currently {}".format(type(a)))
if not isinstance(shape, (list, tuple)):
raise TypeError("shape must be list, tuple, currently {}".format(type(shape)))
-# check axis parameter
-if axis is None:
-    axis = a.split
-stride_tricks.sanitize_axis(shape, axis)
+# check new_split parameter
+if new_split is None:
+    new_split = a.split
+stride_tricks.sanitize_axis(shape, new_split)
tdtype, tdevice = a.dtype.torch_type(), a.device.torch_device
# Check the type of shape and number elements
shape = stride_tricks.sanitize_shape(shape)
@@ -1005,21 +1005,21 @@ def reshape_argsort_counts_displs(
)

# Create new flat result tensor
-_, local_shape, _ = a.comm.chunk(shape, axis)
+_, local_shape, _ = a.comm.chunk(shape, new_split)
data = torch.empty(local_shape, dtype=tdtype, device=tdevice).flatten()

# Calculate the counts and displacements
_, old_displs, _ = a.comm.counts_displs_shape(a.shape, a.split)
-_, new_displs, _ = a.comm.counts_displs_shape(shape, axis)
+_, new_displs, _ = a.comm.counts_displs_shape(shape, new_split)

old_displs += (a.shape[a.split],)
-new_displs += (shape[axis],)
+new_displs += (shape[new_split],)

sendsort, sendcounts, senddispls = reshape_argsort_counts_displs(
-    a.shape, a.lshape, old_displs, a.split, shape, new_displs, axis, a.comm
+    a.shape, a.lshape, old_displs, a.split, shape, new_displs, new_split, a.comm
)
recvsort, recvcounts, recvdispls = reshape_argsort_counts_displs(
-    shape, local_shape, new_displs, axis, a.shape, old_displs, a.split, a.comm
+    shape, local_shape, new_displs, new_split, a.shape, old_displs, a.split, a.comm
)

# rearange order
@@ -1033,7 +1033,7 @@ def reshape_argsort_counts_displs(
# Reshape local tensor
data = data.reshape(local_shape)

-return factories.array(data, dtype=a.dtype, is_split=axis, device=a.device, comm=a.comm)
+return factories.array(data, dtype=a.dtype, is_split=new_split, device=a.device, comm=a.comm)


def rot90(m, k=1, axes=(0, 1)):
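The redistribution above hinges on matching send/receive counts and displacements on the old and new split axes. A conceptual sketch of that arithmetic (not Heat's actual `counts_displs_shape` implementation; the even block distribution is an assumption for illustration):

    # Conceptual sketch: block-distribute n elements over p ranks,
    # giving each rank's element count and starting offset.
    def counts_displs(n, p):
        base, rem = divmod(n, p)
        counts = [base + 1 if r < rem else base for r in range(p)]
        displs = [sum(counts[:r]) for r in range(p)]
        return counts, displs

    counts, displs = counts_displs(10, 4)
    # counts == [3, 3, 2, 2], displs == [0, 3, 6, 8]

With the old and new displacements in hand, each rank can sort its flattened local elements by destination and drive an Alltoallv-style exchange, which is what `reshape_argsort_counts_displs` prepares in the code above.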
6 changes: 3 additions & 3 deletions heat/core/tests/test_manipulations.py
@@ -1082,23 +1082,23 @@ def test_reshape(self):

a = ht.array(torch.arange(3 * 4 * 5).reshape([3, 4, 5]), split=2)
result = ht.array(torch.arange(4 * 5 * 3).reshape([4, 5, 3]), split=1)
-reshaped = ht.reshape(a, [4, 5, 3], axis=1)
+reshaped = ht.reshape(a, [4, 5, 3], new_split=1)
self.assertEqual(reshaped.size, result.size)
self.assertEqual(reshaped.shape, result.shape)
self.assertEqual(reshaped.split, 1)
self.assertTrue(ht.equal(reshaped, result))

a = ht.array(torch.arange(3 * 4 * 5).reshape([3, 4, 5]), split=1)
result = ht.array(torch.arange(4 * 5 * 3).reshape([4 * 5, 3]), split=0)
-reshaped = ht.reshape(a, [4 * 5, 3], axis=0)
+reshaped = ht.reshape(a, [4 * 5, 3], new_split=0)
self.assertEqual(reshaped.size, result.size)
self.assertEqual(reshaped.shape, result.shape)
self.assertEqual(reshaped.split, 0)
self.assertTrue(ht.equal(reshaped, result))

a = ht.array(torch.arange(3 * 4 * 5).reshape([3, 4, 5]), split=0)
result = ht.array(torch.arange(4 * 5 * 3).reshape([4, 5 * 3]), split=1)
-reshaped = ht.reshape(a, [4, 5 * 3], axis=1)
+reshaped = ht.reshape(a, [4, 5 * 3], new_split=1)
self.assertEqual(reshaped.size, result.size)
self.assertEqual(reshaped.shape, result.shape)
self.assertEqual(reshaped.split, 1)
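As a companion to the tests above, a hedged sketch of the default behavior: per the updated docstring, omitting `new_split` (i.e. leaving it None) keeps the original split axis (inferred from the docstring, not verified against this commit):

    import torch
    import heat as ht

    a = ht.array(torch.arange(3 * 4 * 5).reshape([3, 4, 5]), split=0)
    b = ht.reshape(a, [5, 4, 3])  # new_split omitted, defaults to None
    # Per the docstring, the split axis is carried over: b.split == 0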
