685 ravel function + 696 reshape infer dimension #690

Merged: 31 commits, Mar 24, 2021

Commits
4d1cc13
implement ravel
mtar Oct 14, 2020
91d3a55
Merge branch 'master' into feature/685-ravel
mtar Oct 27, 2020
816944b
add ravel function
mtar Oct 27, 2020
47ae5ec
add support for unknown dimension in reshape
mtar Nov 12, 2020
10fff38
Merge branch 'master' into feature/685-ravel
mtar Nov 12, 2020
9315c0f
test_reshape test exceptions
mtar Nov 12, 2020
fe5d8fd
update CHANGELOG
mtar Nov 12, 2020
ce36224
Update CHANGELOG.md
mtar Nov 12, 2020
e4b379b
Merge branch 'master' into feature/685-ravel
mtar Nov 18, 2020
ff03e7b
Merge branch 'feature/685-ravel' of github.com:helmholtz-analytics/he…
mtar Nov 18, 2020
52cd515
allow all arrays with split=0
mtar Nov 18, 2020
b4980af
fix test_ravel
mtar Nov 18, 2020
93910b4
fix test_ravel
mtar Nov 18, 2020
46aef99
raise coverage reshape
mtar Nov 18, 2020
32f77b7
Merge branch 'master' into feature/685-ravel
mtar Nov 18, 2020
8bc0e37
changes in ravel
mtar Dec 1, 2020
b73fece
rearrange shape calculation
mtar Dec 2, 2020
5e596e0
Merge branch 'master' into feature/685-ravel
coquelin77 Jan 11, 2021
b76b365
Merge branch 'master' into feature/685-ravel
mtar Feb 4, 2021
d8a019d
Merge branch 'master' into feature/685-ravel
mtar Feb 10, 2021
2fedbfd
Merge branch 'master' into feature/685-ravel
coquelin77 Feb 11, 2021
e5ff1e4
add reference link to flatten
mtar Feb 11, 2021
358e721
Merge branch 'master' into feature/685-ravel
coquelin77 Feb 11, 2021
e0a4d92
cast inferrred value to integer
mtar Feb 11, 2021
70ca6b7
Merge branch 'master' into feature/685-ravel
mtar Feb 11, 2021
a66f383
Merge branch 'master' into feature/685-ravel
mtar Feb 24, 2021
ce88735
Merge branch 'master' into feature/685-ravel
mtar Mar 2, 2021
f99015e
rewrite comments
mtar Mar 2, 2021
2fcbf42
Merge branch 'master' into feature/685-ravel
mtar Mar 12, 2021
07534ee
Merge branch 'master' into feature/685-ravel
coquelin77 Mar 16, 2021
b511899
Merge branch 'master' into feature/685-ravel
mtar Mar 22, 2021
6 changes: 4 additions & 2 deletions CHANGELOG.md
@@ -17,6 +17,7 @@
- [#559](https://github.com/helmholtz-analytics/heat/pull/559) Enhancement: `save_netcdf` allows naming dimensions, creating unlimited dimensions, using existing dimensions and variables, slicing
### Manipulations
- [#677](https://github.com/helmholtz-analytics/heat/pull/677) split, vsplit, dsplit, hsplit
- [#690](https://github.com/helmholtz-analytics/heat/pull/690) New feature: `ravel()`
### Statistical Functions
- [#679](https://github.com/helmholtz-analytics/heat/pull/679) New feature: ``histc()`` and ``histogram()``
### Linear Algebra
@@ -37,8 +38,9 @@
- [#716](https://github.com/helmholtz-analytics/heat/pull/716) Bugfix: Finding clusters by spectral gap fails when multiple diffs identical
- [#735](https://github.com/helmholtz-analytics/heat/pull/735) Set return type to bool in relational functions.

# v0.5.2

## Enhancements
### Manipulations
- [#690](https://github.com/helmholtz-analytics/heat/pull/690) Enhancement: reshape accepts shape arguments with one unknown dimension.
- [#706](https://github.com/helmholtz-analytics/heat/pull/706) Bug fix: prevent `__setitem__`, `__getitem__` from modifying key in place

# v0.5.1
27 changes: 27 additions & 0 deletions heat/core/dndarray.py
@@ -1319,6 +1319,10 @@ def flatten(self):
flattened : ht.DNDarray
The flattened tensor

See Also
--------
:func:`~heat.core.manipulations.flatten`

Examples
--------
>>> x = ht.array([[1,2],[3,4]])
@@ -2566,6 +2570,29 @@ def qr(self, tiles_per_proc=1, calc_q=True, overwrite_a=False):
self, tiles_per_proc=tiles_per_proc, calc_q=calc_q, overwrite_a=overwrite_a
)

def ravel(self):
"""
Return a flattened array, as a view of the original data if possible; otherwise a copy is returned.

Returns
-------
ret : DNDarray
flattened array with the same dtype as the original, but with shape (self.size,).

See Also
--------
:func:`~heat.core.manipulations.ravel`

Examples
--------
>>> a = ht.ones((2,3), split=0)
>>> b = a.ravel()
>>> a[0,0] = 4
>>> b
DNDarray([4., 1., 1., 1., 1., 1.], dtype=ht.float32, device=cpu:0, split=0)
"""
return manipulations.ravel(self)

def __repr__(self) -> str:
"""
Computes a printable representation of the passed DNDarray.
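As a reading aid rather than part of the diff: a minimal sketch contrasting the existing `flatten` method (always a copy, per its docstring) with the new `ravel` method (a view where possible), assuming a standard Heat build.

```python
import heat as ht

a = ht.ones((2, 3))

f = a.flatten()  # always a copy: later writes to `a` do not show up in `f`
r = a.ravel()    # a view where possible: writes to `a` show up in `r`

a[0, 0] = 4
print(f)  # DNDarray([1., 1., 1., 1., 1., 1.], ...) -- unchanged copy
print(r)  # DNDarray([4., 1., 1., 1., 1., 1.], ...) -- shares storage with `a`
```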
110 changes: 100 additions & 10 deletions heat/core/manipulations.py
@@ -30,6 +30,7 @@
"hsplit",
"hstack",
"pad",
"ravel",
"repeat",
"reshape",
"resplit",
@@ -772,10 +773,16 @@ def flatten(a):
----------
a : DNDarray
array to collapse

Returns
-------
ret : DNDarray
flattened copy

See Also
--------
:func:`~heat.core.manipulations.ravel`

Examples
--------
>>> a = ht.array([[[1,2],[3,4]],[[5,6],[7,8]]])
@@ -1392,6 +1399,66 @@ def pad(array, pad_width, mode="constant", constant_values=0):
return padded_tensor


def ravel(a):
"""
Return a flattened view of `a` if possible. A copy is returned otherwise.

Parameters
----------
a : DNDarray
array to collapse

Returns
-------
ret : DNDarray
flattened array with the same dtype as a, but with shape (a.size,).

Notes
-----
Returning a view of distributed data is only possible when `split` is None or 0; the returned DNDarray may be unbalanced.
For any other split axis the data would have to be communicated among processes, so `ravel` falls back to `flatten` and returns a copy.

See Also
--------
:func:`~heat.core.manipulations.flatten`

Examples
--------
>>> a = ht.ones((2,3), split=0)
>>> b = ht.ravel(a)
>>> a[0,0] = 4
>>> b
DNDarray([4., 1., 1., 1., 1., 1.], dtype=ht.float32, device=cpu:0, split=0)
"""
sanitation.sanitize_in(a)

if a.split is None:
return factories.array(
torch.flatten(a._DNDarray__array),
dtype=a.dtype,
copy=False,
is_split=None,
device=a.device,
comm=a.comm,
)

# Redistribution necessary
if a.split != 0:
return flatten(a)

result = factories.array(
torch.flatten(a._DNDarray__array),
dtype=a.dtype,
copy=False,
is_split=a.split,
device=a.device,
comm=a.comm,
)

return result
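
As an illustration of the copy-versus-view rule encoded above (a sketch, not part of the diff; run under MPI, e.g. `mpirun -n 2 python ...`, for the distribution to matter):

```python
import heat as ht

# split=0: each process holds whole rows, which are already contiguous
# pieces of the flattened result, so ravel can return a view
a = ht.ones((2, 3), split=0)
b = ht.ravel(a)
a[0, 0] = 4   # visible through b, which wraps the same local tensor

# split=1: the local chunks interleave in row-major order, so a view is
# impossible; ravel falls back to flatten() and returns a copy
c = ht.ones((2, 3), split=1)
d = ht.ravel(c)
c[0, 0] = 4   # NOT visible through d
```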


def repeat(a, repeats, axis=None):
"""
Creates a new DNDarray by repeating elements of array a.
@@ -1667,6 +1734,10 @@ def reshape(a, shape, new_split=None):
reshaped : ht.DNDarray
The DNDarray with the specified shape

See Also
--------
:func:`~heat.core.manipulations.ravel`

Raises
------
ValueError
@@ -1688,17 +1759,8 @@
"""
if not isinstance(a, dndarray.DNDarray):
raise TypeError("'a' must be a DNDarray, currently {}".format(type(a)))
if not isinstance(shape, (list, tuple)):
raise TypeError("shape must be list, tuple, currently {}".format(type(shape)))
# check new_split parameter
if new_split is None:
new_split = a.split
stride_tricks.sanitize_axis(shape, new_split)

tdtype, tdevice = a.dtype.torch_type(), a.device.torch_device
# Check the type of shape and number elements
shape = stride_tricks.sanitize_shape(shape)
if torch.prod(torch.tensor(shape, device=tdevice)) != a.size:
raise ValueError("cannot reshape array of size {} into shape {}".format(a.size, shape))

def reshape_argsort_counts_displs(
shape1, lshape1, displs1, axis1, shape2, displs2, axis2, comm
@@ -1735,6 +1797,34 @@ def reshape_argsort_counts_displs(
displs[1:] = torch.cumsum(counts[:-1], dim=0)
return argsort, counts, displs

if shape == -1:
shape = (a.gnumel,)

if not isinstance(shape, (list, tuple)):
raise TypeError("shape must be list, tuple, currently {}".format(type(shape)))

# check new_split parameter
if new_split is None:
new_split = a.split
stride_tricks.sanitize_axis(shape, new_split)

# Check the type of shape and number elements
shape = stride_tricks.sanitize_shape(shape, -1)

shape = list(shape)
shape_size = torch.prod(torch.tensor(shape, dtype=torch.int, device=tdevice))

# infer unknown dimension
if shape.count(-1) > 1:
raise ValueError("too many unknown dimensions")
elif shape.count(-1) == 1:
pos = shape.index(-1)
shape[pos] = int(-(a.size / shape_size).item())
shape_size *= -shape[pos]

if shape_size != a.size:
raise ValueError("cannot reshape array of size {} into shape {}".format(a.size, shape))

# Forward to Pytorch directly
if a.split is None:
return factories.array(
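For context, a short sketch (not part of the diff) of the inferred-dimension behavior added to `reshape` above, assuming a standard Heat build:

```python
import heat as ht

a = ht.arange(12, split=0)

b = ht.reshape(a, (3, -1))  # -1 is inferred as 12 / 3 = 4, giving shape (3, 4)
c = ht.reshape(a, -1)       # a bare -1 flattens, giving shape (12,)

# ht.reshape(a, (-1, -1))   # would raise ValueError: too many unknown dimensions
# ht.reshape(a, (5, -1))    # would raise ValueError: cannot reshape array of size 12 ...
```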
7 changes: 5 additions & 2 deletions heat/core/stride_tricks.py
@@ -115,14 +115,17 @@ def sanitize_axis(shape, axis):
return axis


def sanitize_shape(shape):
def sanitize_shape(shape, lval: int = 0):
"""
Verifies and normalizes the given shape.

Parameters
----------
shape : int or sequence of ints
Shape of an array.
lval : int
Lowest legal value for a dimension entry.

Returns
-------
@@ -154,7 +157,7 @@ def sanitize_shape(shape):
dimension = int(dimension)
if not isinstance(dimension, int):
raise TypeError("expected sequence object with length >= 0 or a single integer")
if dimension < 0:
if dimension < lval:
raise ValueError("negative dimensions are not allowed")

return shape
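Finally, a sketch (not part of the diff) of what the new `lval` parameter buys: it lowers the smallest accepted dimension so that `reshape` can pass `-1` through `sanitize_shape` and resolve it later.

```python
from heat.core.stride_tricks import sanitize_shape

sanitize_shape((3, 4))            # -> (3, 4)
sanitize_shape((3, -1), lval=-1)  # -> (3, -1): -1 is now tolerated
# sanitize_shape((3, -1))         # would raise ValueError with the default lval=0
```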