
Bug/squeeze split semantics #562

Merged
10 commits merged on May 7, 2020
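This pull request fixes the split semantics of `ht.squeeze()`: a distributed DNDarray now keeps its split dimension after squeezing, with the numerical `split` value shifted down by the number of removed axes preceding it, and is only resplit to `None` when the split axis itself is squeezed away. A minimal sketch of the new behaviour, based on the examples added to the docstring and tests in this diff (shapes are illustrative only):

>>> import heat as ht
>>> x = ht.zeros((10, 1, 12, 13), split=2)
>>> x.squeeze().shape           # the length-1 axis 1 is removed
(10, 12, 13)
>>> x.squeeze().split           # split index shifts from 2 to 1
1
>>> y = ht.zeros((3, 1, 5, 6), split=1)
>>> y.squeeze(axis=1).split is None   # squeezing the split axis drops the split
True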
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -2,8 +2,8 @@

- Update documentation theme to "Read the Docs"
- [#429](https://github.com/helmholtz-analytics/heat/pull/429) Create submodule for Linear Algebra functions
- [#429](https://github.com/helmholtz-analytics/heat/pull/429) Implementated QR
- [#429](https://github.com/helmholtz-analytics/heat/pull/429) Implementated a tiling class to create Square tiles along the diagonal of a 2D matrix
- [#429](https://github.com/helmholtz-analytics/heat/pull/429) Implemented QR
- [#429](https://github.com/helmholtz-analytics/heat/pull/429) Implemented a tiling class to create Square tiles along the diagonal of a 2D matrix
- [#429](https://github.com/helmholtz-analytics/heat/pull/429) Added PyTorch Jitter to inner function of matmul for increased speed
- [#483](https://github.com/helmholtz-analytics/heat/pull/483) Bugfix: Underlying torch tensor moves to the right device on array initialisation
- [#483](https://github.com/helmholtz-analytics/heat/pull/483) Bugfix: DNDarray.cpu() changes heat device to cpu
@@ -29,6 +29,7 @@
- [#536](https://github.com/helmholtz-analytics/heat/pull/536) Getting rid of the docs folder
- [#558](https://github.com/helmholtz-analytics/heat/pull/558) `sanitize_memory_layout` assumes default memory layout of the input tensor
- [#558](https://github.com/helmholtz-analytics/heat/pull/558) Support for PyTorch 1.5.0 added
- [#562](https://github.com/helmholtz-analytics/heat/pull/562) Bugfix: split semantics of ht.squeeze()

# v0.3.0

69 changes: 28 additions & 41 deletions heat/core/manipulations.py
@@ -1196,9 +1196,10 @@ def squeeze(x, axis=None):
--------
squeezed : ht.DNDarray
The input tensor, but with all or a subset of the dimensions of length 1 removed.

Split semantics: see note below.

Examples:
---------
>>> import heat as ht
>>> import torch
>>> torch.manual_seed(1)
@@ -1226,6 +1227,19 @@
Traceback (most recent call last):
...
ValueError: Dimension along axis 1 is not 1 for shape (1, 3, 1, 5)

Note:
-----
Split semantics: a distributed tensor keeps its original split dimension after squeezing,
although the numerical 'split' value may decrease if axes below it are removed; if the split
dimension itself is squeezed away, the result is resplit to None, as in:
>>> x.shape
(10, 1, 12, 13)
>>> x.split
2
>>> x.squeeze().shape
(10, 12, 13)
>>> x.squeeze().split
1
"""

# Sanitize input
@@ -1236,58 +1250,31 @@
if axis is not None:
if isinstance(axis, int):
dim_is_one = x.shape[axis] == 1
if isinstance(axis, tuple):
dim_is_one = bool(
factories.array(list(x.shape[dim] == 1 for dim in axis)).all()._DNDarray__array
)
axis = (axis,)
elif isinstance(axis, tuple):
dim_is_one = bool(torch.tensor(list(x.shape[dim] == 1 for dim in axis)).all())
if not dim_is_one:
raise ValueError("Dimension along axis {} is not 1 for shape {}".format(axis, x.shape))

# Local squeeze
if axis is None:
axis = tuple(i for i, dim in enumerate(x.shape) if dim == 1)
if isinstance(axis, int):
axis = (axis,)
out_lshape = tuple(x.lshape[dim] for dim in range(len(x.lshape)) if dim not in axis)

if x.split is not None and x.split in axis:
# split dimension is about to disappear, set split to None
x.resplit_(axis=None)

out_lshape = tuple(x.lshape[dim] for dim in range(x.numdims) if dim not in axis)
out_gshape = tuple(x.gshape[dim] for dim in range(x.numdims) if dim not in axis)
x_lsqueezed = x._DNDarray__array.reshape(out_lshape)

# Calculate split axis according to squeezed shape
# Calculate new split axis according to squeezed shape
if x.split is not None:
split = x.split - len(list(dim for dim in axis if dim < x.split))
else:
split = x.split

# Distributed squeeze
if x.split is not None:
if x.comm.is_distributed():
if x.split in axis:
raise ValueError(
"Cannot split AND squeeze along same axis. Split is {}, axis is {} for shape {}".format(
x.split, axis, x.shape
)
)
out_gshape = tuple(x.gshape[dim] for dim in range(len(x.gshape)) if dim not in axis)
x_gsqueezed = factories.empty(out_gshape, dtype=x.dtype)
loffset = factories.zeros(1, dtype=types.int64)
loffset.__setitem__(0, x.comm.chunk(x.gshape, x.split)[0])
displs = factories.zeros(x.comm.size, dtype=types.int64)
x.comm.Allgather(loffset, displs)

# TODO: address uneven distribution of dimensions (Allgatherv). Issue #273, #233
x.comm.Allgather(
x_lsqueezed, x_gsqueezed
) # works with evenly distributed dimensions only
return dndarray.DNDarray(
x_gsqueezed,
out_gshape,
x_lsqueezed.dtype,
split=split,
device=x.device,
comm=x.comm,
)
split = None

return dndarray.DNDarray(
x_lsqueezed, out_lshape, x.dtype, split=split, device=x.device, comm=x.comm
x_lsqueezed, out_gshape, x.dtype, split=split, device=x.device, comm=x.comm
)
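The new split value above follows a simple rule: count how many of the squeezed axes lie below the original split axis and subtract that count; if the split axis itself is among the squeezed axes, the array has already been resplit to `None` earlier in the function. A condensed, framework-free sketch of that rule (the helper name is hypothetical, not part of this PR):

def squeezed_split(split, axis):
    # split: original split dimension or None; axis: tuple of squeezed dimensions
    if split is None or split in axis:
        return None  # not distributed, or the split dimension disappears
    return split - sum(1 for dim in axis if dim < split)

# e.g. squeezed_split(2, (1,)) == 1, and squeezed_split(1, (1,)) is None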


58 changes: 36 additions & 22 deletions heat/core/tests/test_manipulations.py
@@ -1129,6 +1129,7 @@ def test_sort(self):
ht.sort(data, axis="1")

rank = ht.MPI_WORLD.rank
ht.random.seed(1)
data = ht.random.randn(100, 1, split=0, device=ht_device)
result, _ = ht.sort(data, axis=0)
counts, _, _ = ht.get_comm().counts_displs_shape(data.gshape, axis=0)
@@ -1307,40 +1308,53 @@ def test_squeeze(self):
self.assertTrue((result._DNDarray__array == data._DNDarray__array.squeeze()).all())

# 4D split tensor, along the axis
# TODO: reinstate this test of uneven dimensions distribution
# after update to Allgatherv implementation (Issue #273 depending on #233)
# data = ht.array(ht.random.randn(1, 4, 5, 1), split=1)
# result = ht.squeeze(data, axis=-1)
# self.assertIsInstance(result, ht.DNDarray)
# # TODO: the following works locally but not when distributed,
# #self.assertEqual(result.dtype, ht.float32)
# #self.assertEqual(result._DNDarray__array.dtype, torch.float32)
# self.assertEqual(result.shape, (1, 12, 5))
# self.assertEqual(result.lshape, (1, 12, 5))
# self.assertEqual(result.split, 1)
data = ht.array(ht.random.randn(1, 4, 5, 1), split=1)
result = ht.squeeze(data, axis=-1)
self.assertIsInstance(result, ht.DNDarray)
self.assertEqual(result.dtype, ht.float32)
self.assertEqual(result._DNDarray__array.dtype, torch.float32)
self.assertEqual(result.shape, (1, 4, 5))
self.assertEqual(result.split, 1)

# 4D split tensor, axis = split
data = ht.array(ht.random.randn(3, 1, 5, 6), split=1)
result = ht.squeeze(data, axis=1)
self.assertIsInstance(result, ht.DNDarray)
self.assertEqual(result.dtype, ht.float32)
self.assertEqual(result._DNDarray__array.dtype, torch.float32)
self.assertEqual(result.shape, (3, 5, 6))
self.assertEqual(result.split, None)

# 4D split tensor, axis = split = last dimension
data = ht.array(ht.random.randn(3, 6, 5, 1), split=-1)
result = ht.squeeze(data, axis=-1)
self.assertIsInstance(result, ht.DNDarray)
self.assertEqual(result.dtype, ht.float32)
self.assertEqual(result._DNDarray__array.dtype, torch.float32)
self.assertEqual(result.shape, (3, 6, 5))
self.assertEqual(result.split, None)

# 3D split tensor, across the axis
size = ht.MPI_WORLD.size * 2
data = ht.triu(ht.ones((1, size, size), split=1, device=ht_device), k=1)
size = ht.MPI_WORLD.size
data = ht.triu(ht.ones((1, size * 2, size), split=1, device=ht_device), k=1)

result = ht.squeeze(data, axis=0)
self.assertIsInstance(result, ht.DNDarray)
# TODO: the following works locally but not when distributed,
# self.assertEqual(result.dtype, ht.float32)
# self.assertEqual(result._DNDarray__array.dtype, torch.float32)
self.assertEqual(result.shape, (size, size))
self.assertEqual(result.lshape, (size, size))
# self.assertEqual(result.split, None)
self.assertEqual(result.dtype, ht.float32)
self.assertEqual(result._DNDarray__array.dtype, torch.float32)
self.assertEqual(result.shape, (size * 2, size))
self.assertEqual(result.lshape, (2, size))
self.assertEqual(result.split, 0)

# check exceptions
with self.assertRaises(ValueError):
data.squeeze(axis=(0, 1))
with self.assertRaises(TypeError):
data.squeeze(axis=1.1)
with self.assertRaises(TypeError):
data.squeeze(axis="y")
with self.assertRaises(ValueError):
ht.argmin(data, axis=-4)
ht.squeeze(data, axis=-4)
with self.assertRaises(ValueError):
ht.squeeze(data, axis=1)
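Condensed into a standalone script, the distributed cases added above assert that squeezing a non-split axis keeps the split dimension, while squeezing the split axis itself, whether an inner axis or the last one, resplits the result to `None`. A minimal sketch mirroring those assertions (it also runs on a single process; under MPI with several ranks it exercises the distributed code path):

import heat as ht

a = ht.random.randn(1, 4, 5, 1, split=1)
assert ht.squeeze(a, axis=-1).split == 1       # split axis untouched, index unchanged

b = ht.random.randn(3, 1, 5, 6, split=1)
assert ht.squeeze(b, axis=1).split is None     # split axis squeezed away

c = ht.random.randn(3, 6, 5, 1, split=-1)
assert ht.squeeze(c, axis=-1).split is None    # split is the last axis and is squeezed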

def test_unique(self):
size = ht.MPI_WORLD.size