Commit bd242b0

Merge pull request #562 from helmholtz-analytics/bug/squeeze-split-semantics

Bug/squeeze split semantics
coquelin77 authored May 7, 2020
2 parents 0784ed6 + 6ca24b8 commit bd242b0
Showing 3 changed files with 67 additions and 65 deletions.
CHANGELOG.md (5 changes: 3 additions & 2 deletions)

@@ -2,8 +2,8 @@
 - Update documentation theme to "Read the Docs"
 - [#429](https://github.com/helmholtz-analytics/heat/pull/429) Create submodule for Linear Algebra functions
-- [#429](https://github.com/helmholtz-analytics/heat/pull/429) Implementated QR
-- [#429](https://github.com/helmholtz-analytics/heat/pull/429) Implementated a tiling class to create Square tiles along the diagonal of a 2D matrix
+- [#429](https://github.com/helmholtz-analytics/heat/pull/429) Implemented QR
+- [#429](https://github.com/helmholtz-analytics/heat/pull/429) Implemented a tiling class to create Square tiles along the diagonal of a 2D matrix
 - [#429](https://github.com/helmholtz-analytics/heat/pull/429) Added PyTorch Jitter to inner function of matmul for increased speed
 - [#483](https://github.com/helmholtz-analytics/heat/pull/483) Bugfix: Underlying torch tensor moves to the right device on array initialisation
 - [#483](https://github.com/helmholtz-analytics/heat/pull/483) Bugfix: DNDarray.cpu() changes heat device to cpu
@@ -29,6 +29,7 @@
 - [#536](https://github.com/helmholtz-analytics/heat/pull/536) Getting rid of the docs folder
 - [#558](https://github.com/helmholtz-analytics/heat/pull/558) `sanitize_memory_layout` assumes default memory layout of the input tensor
 - [#558](https://github.com/helmholtz-analytics/heat/pull/558) Support for PyTorch 1.5.0 added
+- [#562](https://github.com/helmholtz-analytics/heat/pull/562) Bugfix: split semantics of ht.squeeze()
 
 # v0.3.0

heat/core/manipulations.py (69 changes: 28 additions & 41 deletions)

@@ -1196,9 +1196,10 @@ def squeeze(x, axis=None):
     --------
     squeezed : ht.DNDarray
         The input tensor, but with all or a subset of the dimensions of length 1 removed.
+        Split semantics: see note below.
 
     Examples:
     ---------
     >>> import heat as ht
     >>> import torch
     >>> torch.manual_seed(1)
@@ -1226,6 +1227,19 @@
     Traceback (most recent call last):
     ...
     ValueError: Dimension along axis 1 is not 1 for shape (1, 3, 1, 5)
+
+    Note:
+    -----
+    Split semantics: a distributed tensor will keep its original split dimension after "squeezing",
+    which, depending on the squeeze axis, may result in a lower numerical 'split' value, as in:
+    >>> x.shape
+    (10, 1, 12, 13)
+    >>> x.split
+    2
+    >>> x.squeeze().shape
+    (10, 12, 13)
+    >>> x.squeeze().split
+    1
     """
 
     # Sanitize input
@@ -1236,58 +1250,31 @@
     if axis is not None:
         if isinstance(axis, int):
             dim_is_one = x.shape[axis] == 1
-        if isinstance(axis, tuple):
-            dim_is_one = bool(
-                factories.array(list(x.shape[dim] == 1 for dim in axis)).all()._DNDarray__array
-            )
+            axis = (axis,)
+        elif isinstance(axis, tuple):
+            dim_is_one = bool(torch.tensor(list(x.shape[dim] == 1 for dim in axis)).all())
         if not dim_is_one:
             raise ValueError("Dimension along axis {} is not 1 for shape {}".format(axis, x.shape))
 
     # Local squeeze
     if axis is None:
         axis = tuple(i for i, dim in enumerate(x.shape) if dim == 1)
-    if isinstance(axis, int):
-        axis = (axis,)
-    out_lshape = tuple(x.lshape[dim] for dim in range(len(x.lshape)) if dim not in axis)
 
+    if x.split is not None and x.split in axis:
+        # split dimension is about to disappear, set split to None
+        x.resplit_(axis=None)
+
+    out_lshape = tuple(x.lshape[dim] for dim in range(x.numdims) if dim not in axis)
+    out_gshape = tuple(x.gshape[dim] for dim in range(x.numdims) if dim not in axis)
     x_lsqueezed = x._DNDarray__array.reshape(out_lshape)
 
-    # Calculate split axis according to squeezed shape
+    # Calculate new split axis according to squeezed shape
     if x.split is not None:
         split = x.split - len(list(dim for dim in axis if dim < x.split))
     else:
-        split = x.split
-
-    # Distributed squeeze
-    if x.split is not None:
-        if x.comm.is_distributed():
-            if x.split in axis:
-                raise ValueError(
-                    "Cannot split AND squeeze along same axis. Split is {}, axis is {} for shape {}".format(
-                        x.split, axis, x.shape
-                    )
-                )
-            out_gshape = tuple(x.gshape[dim] for dim in range(len(x.gshape)) if dim not in axis)
-            x_gsqueezed = factories.empty(out_gshape, dtype=x.dtype)
-            loffset = factories.zeros(1, dtype=types.int64)
-            loffset.__setitem__(0, x.comm.chunk(x.gshape, x.split)[0])
-            displs = factories.zeros(x.comm.size, dtype=types.int64)
-            x.comm.Allgather(loffset, displs)
-
-            # TODO: address uneven distribution of dimensions (Allgatherv). Issue #273, #233
-            x.comm.Allgather(
-                x_lsqueezed, x_gsqueezed
-            )  # works with evenly distributed dimensions only
-            return dndarray.DNDarray(
-                x_gsqueezed,
-                out_gshape,
-                x_lsqueezed.dtype,
-                split=split,
-                device=x.device,
-                comm=x.comm,
-            )
+        split = None
 
     return dndarray.DNDarray(
-        x_lsqueezed, out_lshape, x.dtype, split=split, device=x.device, comm=x.comm
+        x_lsqueezed, out_gshape, x.dtype, split=split, device=x.device, comm=x.comm
     )
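The split arithmetic above is compact, so here is a minimal standalone sketch of the same bookkeeping in plain Python (no MPI or heat required; the helper name `new_split` and the sample values are hypothetical, not part of the patch):

```python
# Sketch of the split bookkeeping used by the patched squeeze().
def new_split(old_split, axis):
    # axis: tuple of dimensions being squeezed; old_split: split dim or None
    if old_split is None:
        return None
    if old_split in axis:
        # the split dimension disappears, so the result is resplit to None
        return None
    # each squeezed dimension below the split shifts the split index down by one
    return old_split - len([dim for dim in axis if dim < old_split])

print(new_split(2, (1,)))  # 1    (matches the docstring note: split 2 -> 1)
print(new_split(1, (1,)))  # None (squeeze along the split axis)
print(new_split(0, (3,)))  # 0    (squeezing above the split changes nothing)
```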


heat/core/tests/test_manipulations.py (58 changes: 36 additions & 22 deletions)

@@ -1129,6 +1129,7 @@ def test_sort(self):
             ht.sort(data, axis="1")
 
         rank = ht.MPI_WORLD.rank
+        ht.random.seed(1)
         data = ht.random.randn(100, 1, split=0, device=ht_device)
         result, _ = ht.sort(data, axis=0)
         counts, _, _ = ht.get_comm().counts_displs_shape(data.gshape, axis=0)
@@ -1307,40 +1308,53 @@
         self.assertTrue((result._DNDarray__array == data._DNDarray__array.squeeze()).all())
 
         # 4D split tensor, along the axis
-        # TODO: reinstate this test of uneven dimensions distribution
-        # after update to Allgatherv implementation (Issue #273 depending on #233)
-        # data = ht.array(ht.random.randn(1, 4, 5, 1), split=1)
-        # result = ht.squeeze(data, axis=-1)
-        # self.assertIsInstance(result, ht.DNDarray)
-        # # TODO: the following works locally but not when distributed,
-        # #self.assertEqual(result.dtype, ht.float32)
-        # #self.assertEqual(result._DNDarray__array.dtype, torch.float32)
-        # self.assertEqual(result.shape, (1, 12, 5))
-        # self.assertEqual(result.lshape, (1, 12, 5))
-        # self.assertEqual(result.split, 1)
+        data = ht.array(ht.random.randn(1, 4, 5, 1), split=1)
+        result = ht.squeeze(data, axis=-1)
+        self.assertIsInstance(result, ht.DNDarray)
+        self.assertEqual(result.dtype, ht.float32)
+        self.assertEqual(result._DNDarray__array.dtype, torch.float32)
+        self.assertEqual(result.shape, (1, 4, 5))
+        self.assertEqual(result.split, 1)
 
+        # 4D split tensor, axis = split
+        data = ht.array(ht.random.randn(3, 1, 5, 6), split=1)
+        result = ht.squeeze(data, axis=1)
+        self.assertIsInstance(result, ht.DNDarray)
+        self.assertEqual(result.dtype, ht.float32)
+        self.assertEqual(result._DNDarray__array.dtype, torch.float32)
+        self.assertEqual(result.shape, (3, 5, 6))
+        self.assertEqual(result.split, None)
+
+        # 4D split tensor, axis = split = last dimension
+        data = ht.array(ht.random.randn(3, 6, 5, 1), split=-1)
+        result = ht.squeeze(data, axis=-1)
+        self.assertIsInstance(result, ht.DNDarray)
+        self.assertEqual(result.dtype, ht.float32)
+        self.assertEqual(result._DNDarray__array.dtype, torch.float32)
+        self.assertEqual(result.shape, (3, 6, 5))
+        self.assertEqual(result.split, None)
+
         # 3D split tensor, across the axis
-        size = ht.MPI_WORLD.size * 2
-        data = ht.triu(ht.ones((1, size, size), split=1, device=ht_device), k=1)
+        size = ht.MPI_WORLD.size
+        data = ht.triu(ht.ones((1, size * 2, size), split=1, device=ht_device), k=1)
 
         result = ht.squeeze(data, axis=0)
         self.assertIsInstance(result, ht.DNDarray)
-        # TODO: the following works locally but not when distributed,
-        # self.assertEqual(result.dtype, ht.float32)
-        # self.assertEqual(result._DNDarray__array.dtype, torch.float32)
-        self.assertEqual(result.shape, (size, size))
-        self.assertEqual(result.lshape, (size, size))
-        # self.assertEqual(result.split, None)
+        self.assertEqual(result.dtype, ht.float32)
+        self.assertEqual(result._DNDarray__array.dtype, torch.float32)
+        self.assertEqual(result.shape, (size * 2, size))
+        self.assertEqual(result.lshape, (2, size))
+        self.assertEqual(result.split, 0)
 
         # check exceptions
         with self.assertRaises(ValueError):
             data.squeeze(axis=(0, 1))
         with self.assertRaises(TypeError):
             data.squeeze(axis=1.1)
         with self.assertRaises(TypeError):
             data.squeeze(axis="y")
         with self.assertRaises(ValueError):
-            ht.argmin(data, axis=-4)
+            ht.squeeze(data, axis=-4)
+        with self.assertRaises(ValueError):
+            ht.squeeze(data, axis=1)
 
     def test_unique(self):
         size = ht.MPI_WORLD.size
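Taken together, the updated tests pin down the new behaviour end to end. A minimal usage sketch along the same lines (assuming a heat build containing this patch and, for genuinely distributed splits, a multi-process MPI launch; the shapes are illustrative only):

```python
# End-to-end sketch of the squeeze() split semantics asserted above.
import heat as ht

# Squeezing a non-split axis: result stays distributed, split index kept.
a = ht.ones((1, 4, 5, 1), split=1)
b = ht.squeeze(a, axis=-1)  # shape (1, 4, 5), still split along dim 1
assert b.shape == (1, 4, 5) and b.split == 1

# Squeezing the split axis: data is resplit, result split becomes None.
c = ht.ones((3, 1, 5, 6), split=1)
d = ht.squeeze(c, axis=1)  # shape (3, 5, 6), no longer distributed
assert d.shape == (3, 5, 6) and d.split is None

# Squeezing an axis below the split lowers the split index.
e = ht.ones((1, 8, 6), split=1)
f = ht.squeeze(e, axis=0)  # shape (8, 6), split moves from 1 to 0
assert f.shape == (8, 6) and f.split == 0
```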
