Bug/591 slicing memory issues #594

Merged: 8 commits (Jun 15, 2020)
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -4,7 +4,8 @@
- [#577](https://github.com/helmholtz-analytics/heat/pull/577) Add ndim property in dndarray
- [#578](https://github.com/helmholtz-analytics/heat/pull/578) Bugfix: Bad variable in reshape
- [#580](https://github.com/helmholtz-analytics/heat/pull/580) New feature: fliplr()
+ - [#594](https://github.com/helmholtz-analytics/heat/pull/594) New feature: Advanced indexing
+ - [#594](https://github.com/helmholtz-analytics/heat/pull/594) Bugfix: getitem and setitem memory consumption heavily reduced

# v0.4.0

6 changes: 3 additions & 3 deletions heat/core/arithmetics.py
@@ -319,7 +319,7 @@ def diff(a, n=1, axis=-1):
axis_slice[axis] = slice(1, None, None)
axis_slice_end = [slice(None)] * len(ret.shape)
axis_slice_end[axis] = slice(None, -1, None)
- ret = ret[axis_slice] - ret[axis_slice_end]
+ ret = ret[tuple(axis_slice)] - ret[tuple(axis_slice_end)]
return ret

size = a.comm.size
@@ -364,9 +364,9 @@ def diff(a, n=1, axis=-1):
recv_data.reshape(ret.lloc[axis_slice_end].shape) - ret.lloc[axis_slice_end]
)

- axis_slice_end = [slice(None)] * len(a.shape)
+ axis_slice_end = [slice(None, None, None)] * len(a.shape)
axis_slice_end[axis] = slice(None, -1 * n, None)
- ret = ret[axis_slice_end]  # slice of the last element on the array (nonsense data)
+ ret = ret[tuple(axis_slice_end)]  # slice off the last element on the array (nonsense data)
ret.balance_() # balance the array before returning
return ret

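A quick aside on the tuple() wrapping above (not part of the diff): indexing a torch tensor with a Python list of slices is deprecated and gets interpreted as advanced indexing, while a tuple of slices performs ordinary dimension-wise slicing. A minimal sketch:

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

# Build one slice per dimension, then convert the list to a tuple.
axis = 1
axis_slice = [slice(None)] * x.dim()
axis_slice[axis] = slice(1, None, None)

# x[axis_slice] triggers a deprecation warning on newer PyTorch;
# x[tuple(axis_slice)] is plain dimension-wise slicing.
y = x[tuple(axis_slice)]
print(y.shape)  # torch.Size([2, 2, 4])
```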
501 changes: 234 additions & 267 deletions heat/core/dndarray.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions heat/core/indexing.py
@@ -77,6 +77,9 @@ def nonzero(a):

if a.ndim == 1:
lcl_nonzero = lcl_nonzero.squeeze(dim=1)
+ for g in range(len(gout) - 1, -1, -1):
+     if gout[g] == 1:
+         del gout[g]

return dndarray.DNDarray(
lcl_nonzero,
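The added loop trims size-1 entries from the global output shape for the 1-D case. It iterates in reverse so that deletions do not shift the indices still to be visited; a standalone sketch with hypothetical values:

```python
# Walking backwards keeps the remaining indices valid while deleting.
gout = [5, 1, 1]
for g in range(len(gout) - 1, -1, -1):
    if gout[g] == 1:
        del gout[g]
print(gout)  # [5]
```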
4 changes: 2 additions & 2 deletions heat/core/manipulations.py
@@ -201,7 +201,7 @@ def concatenate(arrays, axis=0):
arr0 = arr0.copy()
arr1 = arr1.copy()
# maps are created for where the data is and the output shape is calculated
- lshape_map = factories.zeros((2, arr0.comm.size, len(arr0.gshape)), dtype=int)
+ lshape_map = torch.zeros((2, arr0.comm.size, len(arr0.gshape)), dtype=torch.int)
lshape_map[0, arr0.comm.rank, :] = torch.Tensor(arr0.lshape)
lshape_map[1, arr0.comm.rank, :] = torch.Tensor(arr1.lshape)
lshape_map_comm = arr0.comm.Iallreduce(MPI.IN_PLACE, lshape_map, MPI.SUM)
@@ -211,7 +211,7 @@ def concatenate(arrays, axis=0):
out_shape = tuple(arr0_shape)

# the chunk map is used for determine how much data should be on each process
- chunk_map = factories.zeros((arr0.comm.size, len(arr0.gshape)), dtype=int)
+ chunk_map = torch.zeros((arr0.comm.size, len(arr0.gshape)), dtype=torch.int)
_, _, chk = arr0.comm.chunk(out_shape, s0 if s0 is not None else s1)
for i in range(len(out_shape)):
chunk_map[arr0.comm.rank, i] = chk[i].stop - chk[i].start
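Both maps switch from factories.zeros (a distributed DNDarray) to torch.zeros (a process-local tensor): the maps are pure bookkeeping, so distributing them only added memory and communication overhead. A rough standalone sketch of the layout, with hypothetical communicator values:

```python
import torch

nprocs, ndims, rank = 4, 2, 0  # hypothetical communicator state

# One row per process, one column per array dimension.
lshape_map = torch.zeros((nprocs, ndims), dtype=torch.int)
lshape_map[rank, :] = torch.tensor([3, 5])  # this rank holds a 3x5 chunk
# In the PR, an MPI.SUM (I)allreduce then fills in the other ranks' rows.
```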
13 changes: 10 additions & 3 deletions heat/core/operations.py
@@ -383,8 +383,14 @@ def __reduce_op(x, partial_op, reduction_op, neutral=None, **kwargs):
if 0 in x.lshape and (axis is None or (x.split in axis)):
if neutral is None:
neutral = float("nan")
- neutral_shape = x.lshape[:split] + (1,) + x.lshape[split + 1 :]
- partial = torch.full(neutral_shape, fill_value=neutral, dtype=x._DNDarray__array.dtype)
+ neutral_shape = x.gshape[:split] + (1,) + x.gshape[split + 1 :]
+ partial = torch.full(
+     neutral_shape,
+     fill_value=neutral,
+     dtype=x.dtype.torch_type(),
+     device=x.device.torch_device,
+ )

else:
partial = x._DNDarray__array

@@ -406,7 +412,8 @@ def __reduce_op(x, partial_op, reduction_op, neutral=None, **kwargs):
lshape_losedim = (partial.shape[0],) + lshape_losedim
if 0 not in axis and partial.shape[0] != x.lshape[0]:
lshape_losedim = (partial.shape[0],) + lshape_losedim[1:]
- partial = partial.reshape(lshape_losedim)
+ if len(lshape_losedim) > 0:
+     partial = partial.reshape(lshape_losedim)

# Check shape of output buffer, if any
if out is not None and out.shape != output_shape:
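Two fixes in this hunk: the neutral fill is now built from the global shape (gshape) with the split dimension collapsed to 1, on the correct device, and the reshape is skipped when the result has lost all its dimensions. A small sketch of the neutral-element construction, with hypothetical values:

```python
import torch

gshape = (4, 5, 3)       # hypothetical global shape
split = 1                # dimension the array is split along
neutral = float("-inf")  # e.g. the neutral element of a max-reduction

# Collapse the split dimension to 1 and fill with the neutral value, so a
# rank with an empty local chunk still contributes a valid operand.
neutral_shape = gshape[:split] + (1,) + gshape[split + 1 :]
partial = torch.full(neutral_shape, fill_value=neutral, dtype=torch.float32)
print(partial.shape)  # torch.Size([4, 1, 3])
```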
9 changes: 5 additions & 4 deletions heat/core/statistics.py
@@ -331,9 +331,11 @@ def average(x, axis=None, weights=None, returned=False):
)
wgt_slice = [slice(None) if dim == axis else 0 for dim in list(range(x.ndim))]
wgt_split = None if weights.split is None else axis
- wgt = factories.empty(wgt_lshape, dtype=weights.dtype, device=x.device)
- wgt._DNDarray__array[wgt_slice] = weights._DNDarray__array
- wgt = factories.array(wgt._DNDarray__array, is_split=wgt_split)
+ wgt = torch.empty(
+     wgt_lshape, dtype=weights.dtype.torch_type(), device=x.device.torch_device
+ )
+ wgt[wgt_slice] = weights._DNDarray__array
+ wgt = factories.array(wgt, is_split=wgt_split)
else:
if x.comm.is_distributed():
if x.split is not None and weights.split != x.split and weights.ndim != 1:
@@ -345,7 +347,6 @@ def average(x, axis=None, weights=None, returned=False):
wgt._DNDarray__array = weights._DNDarray__array

cumwgt = wgt.sum(axis=axis)

if logical.any(cumwgt == 0.0):
raise ZeroDivisionError("Weights sum to zero, can't be normalized")

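The rewrite fills the expanded weight tensor directly in torch and wraps it in a DNDarray only once, instead of allocating an intermediate DNDarray first. A standalone sketch of the broadcasting pattern with hypothetical shapes (the variable names mirror the diff, but the construction of wgt_lshape is inferred):

```python
import torch

x = torch.ones(3, 4)
weights = torch.tensor([1.0, 2.0, 3.0])
axis = 0

# Weight shape: full size along `axis`, 1 elsewhere, so it broadcasts with x.
wgt_lshape = tuple(x.shape[d] if d == axis else 1 for d in range(x.dim()))
wgt_slice = tuple(slice(None) if d == axis else 0 for d in range(x.dim()))
wgt = torch.empty(wgt_lshape, dtype=weights.dtype)
wgt[wgt_slice] = weights
print((x * wgt).shape)  # torch.Size([3, 4])
```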
28 changes: 28 additions & 0 deletions heat/core/stride_tricks.py
@@ -158,3 +158,31 @@ def sanitize_shape(shape):
raise ValueError("negative dimensions are not allowed")

return shape


+ def sanitize_slice(sl, max_dim) -> slice:
+     """
+     Remove None-types from a slice
+
+     Parameters
+     ----------
+     sl : slice
+         slice to be sanitized
+     max_dim : int
+         maximum index for the given slice
+
+     Raises
+     ------
+     TypeError
+         if sl is not a slice
+     """
+     if not isinstance(sl, slice):
+         raise TypeError("This function is only for slices!")
+     new_sl = [None] * 3
+     new_sl[0] = 0 if sl.start is None else sl.start
+     if new_sl[0] < 0:
+         new_sl[0] += max_dim
+     new_sl[1] = max_dim if sl.stop is None else sl.stop
+     if new_sl[1] < 0:
+         new_sl[1] += max_dim
+     new_sl[2] = 1 if sl.step is None else sl.step
+     return slice(new_sl[0], new_sl[1], new_sl[2])
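For illustration, a brief usage sketch of the new helper (not part of the diff): None bounds are filled in and negative bounds are shifted by max_dim.

```python
import heat as ht

s = ht.core.stride_tricks.sanitize_slice(slice(-8, None, 2), 10)
print(s)  # slice(2, 10, 2)
```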
5 changes: 3 additions & 2 deletions heat/core/tests/test_arithmetics.py
@@ -289,8 +289,9 @@ def test_diff(self):
# loop to 3 for the number of times to do the diff
for nl in range(1, 4):
# only generating the number once and then
- lp_array = ht.manipulations.resplit(ht_array[arb_slice], sp)
- np_array = ht_array[arb_slice].numpy()
+ tup_arb = tuple(arb_slice)
+ lp_array = ht.manipulations.resplit(ht_array[tup_arb], sp)
+ np_array = ht_array[tup_arb].numpy()

ht_diff = ht.diff(lp_array, n=nl, axis=ax)
np_diff = ht.array(np.diff(np_array, n=nl, axis=ax), device=ht_device)
7 changes: 6 additions & 1 deletion heat/core/tests/test_dndarray.py
@@ -778,6 +778,7 @@ def test_setitem_getitem(self):

a = ht.zeros((13, 5), split=0, device=ht_device)
a[-1] = 1
+ # print('here', a)
b = a[-1]
self.assertTrue((b == 1).all())
self.assertEqual(b.dtype, ht.float32)
@@ -961,8 +962,12 @@ def test_setitem_getitem(self):
# setting with heat tensor
a = ht.zeros((4, 5), split=1, device=ht_device)
a[1, 0:4] = ht.arange(4, device=ht_device)
+ # print(a)
+ # print(a[1, 2])
for c, i in enumerate(range(4)):
-     self.assertEqual(a[1, c], i)
+     b = a[1, c]
+     if b._DNDarray__array.numel() > 0:
+         self.assertEqual(b.item(), i)

# setting with torch tensor
a = ht.zeros((4, 5), split=1, device=ht_device)
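The reworked assertion reflects the PR's distributed getitem: on a split array, the element a[1, c] is materialized only on the rank that owns it, so the local torch tensor may be empty elsewhere and .item() must be guarded. A sketch of the pattern, mirroring the test above:

```python
import heat as ht

a = ht.zeros((4, 5), split=1)
b = a[1, 2]  # only the owning rank holds this element's data
if b._DNDarray__array.numel() > 0:
    print(b.item())
```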
2 changes: 1 addition & 1 deletion heat/core/tests/test_indexing.py
@@ -78,7 +78,7 @@ def test_where(self):
[[0.0, 1.0, 2.0], [0.0, 2.0, -1.0], [0.0, 3.0, -1.0]], split=0, device=ht_device
)
wh = ht.where(a < 4.0, a, -1)
- self.assertTrue(ht.all(wh[ht.nonzero(a >= 4)], -1))
+ self.assertTrue(ht.all(wh[ht.nonzero(a >= 4)] == -1))
self.assertTrue(ht.equal(wh, res))
self.assertEqual(wh.gshape, (3, 3))
self.assertEqual(wh.dtype, ht.float)
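The old assertion passed for the wrong reason: ht.all takes the axis as its second positional argument, so ht.all(wh[...], -1) reduced along the last axis rather than comparing against -1, and any nonzero entries counted as True. A small sketch of the difference, assuming heat's all(x, axis=...) signature:

```python
import heat as ht

x = ht.array([2.0, 2.0])
print(ht.all(x, -1))    # True: -1 is the axis, nonzero values reduce to True
print(ht.all(x == -1))  # False: the intended elementwise comparison
```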
15 changes: 15 additions & 0 deletions heat/core/tests/test_stride_tricks.py
@@ -78,3 +78,18 @@ def test_sanitize_shape(self):
ht.core.stride_tricks.sanitize_shape(1.0)
with self.assertRaises(TypeError):
ht.core.stride_tricks.sanitize_shape((1, 1.0))

+ def test_sanitize_slice(self):
+     test_slice = slice(None, None, None)
+     ret_slice = ht.core.stride_tricks.sanitize_slice(test_slice, 100)
+     self.assertEqual(ret_slice.start, 0)
+     self.assertEqual(ret_slice.stop, 100)
+     self.assertEqual(ret_slice.step, 1)
+     test_slice = slice(-50, -5, 2)
+     ret_slice = ht.core.stride_tricks.sanitize_slice(test_slice, 100)
+     self.assertEqual(ret_slice.start, 50)
+     self.assertEqual(ret_slice.stop, 95)
+     self.assertEqual(ret_slice.step, 2)
+
+     with self.assertRaises(TypeError):
+         ht.core.stride_tricks.sanitize_slice("test_slice", 100)
5 changes: 1 addition & 4 deletions heat/naive_bayes/gaussianNB.py
@@ -353,7 +353,7 @@ def __partial_fit(self, X, y, classes=None, _refit=False, sample_weight=None):
(classes._DNDarray__array, y_i._DNDarray__array.unsqueeze(0))
)
i = torch.argsort(classes_ext)[-1].item()
- where_y_i = ht.where(y == y_i)._DNDarray__array.tolist()
+ where_y_i = ht.where(y == y_i)
X_i = X[where_y_i, :]

if sample_weight is not None:
@@ -366,11 +366,9 @@ def __partial_fit(self, X, y, classes=None, _refit=False, sample_weight=None):
else:
sw_i = None
N_i = X_i.shape[0]

new_theta, new_sigma = self.__update_mean_variance(
self.class_count_[i], self.theta_[i, :], self.sigma_[i, :], X_i, sw_i
)

self.theta_[i, :] = new_theta
self.sigma_[i, :] = new_sigma
self.class_count_[i] += N_i
@@ -400,7 +398,6 @@ def __joint_log_likelihood(self, X):
n_ij = -0.5 * ht.sum(ht.log(2.0 * ht.pi * self.sigma_[i, :]))
n_ij -= 0.5 * ht.sum(((X - self.theta_[i, :]) ** 2) / (self.sigma_[i, :]), 1)
joint_log_likelihood[:, i] = jointi + n_ij

return joint_log_likelihood

def logsumexp(self, a, axis=None, b=None, keepdim=False, return_sign=False):
Expand Down