Commit

Merge branch 'master' into feature/685-ravel
mtar authored Mar 22, 2021
2 parents 07534ee + 442c271 commit b511899
Showing 3 changed files with 118 additions and 111 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -42,6 +42,7 @@
 ## Bug fixes
 - [#709](https://github.com/helmholtz-analytics/heat/pull/709) Set the encoding for README.md in setup.py explicitly.
 - [#716](https://github.com/helmholtz-analytics/heat/pull/716) Bugfix: Finding clusters by spectral gap fails when multiple diffs identical
+- [#732](https://github.com/helmholtz-analytics/heat/pull/732) Corrected logic in `DNDarray.__getitem__` to produce the correct split axis
 - [#734](https://github.com/helmholtz-analytics/heat/pull/734) Fix division by zero error in `__local_op` with out != None on empty local arrays.
 - [#735](https://github.com/helmholtz-analytics/heat/pull/735) Set return type to bool in relational functions.

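The #732 entry concerns the `split` attribute that `DNDarray.__getitem__` assigns to its result: an integer index removes a dimension, so when it falls on or before the split axis, the split axis of the result has to shift left accordingly. A minimal illustration of the corrected behaviour, mirroring the regression test added below (the `split` values assume the array is actually distributed, i.e. more than one MPI process):

```python
import heat as ht

# a 3-D array distributed along axis 1
a = ht.ones((10, 25, 30), split=1)

# an integer index on or before the split axis removes a dimension
# in front of (or at) the split, so the result's split axis shifts left
print(a[0].shape, a[0].split)              # (25, 30), split 0
print(a[:, 0, :].shape, a[:, 0, :].split)  # (10, 30), split 0

# an integer index behind the split axis leaves the split position untouched
print(a[:, :, 0].shape, a[:, :, 0].split)  # (10, 25), split 1
```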
221 changes: 110 additions & 111 deletions heat/core/dndarray.py
@@ -1611,122 +1611,121 @@ def __getitem__(self, key):
                 self.comm,
                 self.balanced,
             )
-        else:
-            rank = self.comm.rank
-            ends = []
-            for pr in range(self.comm.size):
-                _, _, e = self.comm.chunk(self.shape, self.split, rank=pr)
-                ends.append(e[self.split].stop - e[self.split].start)
-            ends = torch.tensor(ends, device=self.device.torch_device)
-            chunk_ends = ends.cumsum(dim=0)
-            chunk_starts = torch.tensor([0] + chunk_ends.tolist(), device=self.device.torch_device)
-            chunk_start = chunk_starts[rank]
-            chunk_end = chunk_ends[rank]
-            arr = torch.tensor([], device=self.device.torch_device)
-            # all keys should be tuples here
-            gout = [0] * len(self.gshape)
-            # handle the dimensional reduction for integers
-            ints = sum(isinstance(it, int) for it in key)
-            gout = gout[: len(gout) - ints]
-            if self.split >= len(gout):
-                new_split = len(gout) - 1 if len(gout) - 1 > 0 else 0
-            else:
-                new_split = self.split
-            if len(key) == 0:  # handle empty list
-                # this will return an array of shape (0, ...)
-                arr = self.__array[key]
-                gout_full = list(arr.shape)
-
-            """ At the end of the following if/elif/elif block the output array will be set.
-            each block handles the case where the element of the key along the split axis is a different type
-            and converts the key from global indices to local indices.
-            """
-            if isinstance(key[self.split], (list, torch.Tensor, DNDarray)):
-                # advanced indexing, elements in the split dimension are adjusted to the local indices
-                lkey = list(key)
-                if isinstance(key[self.split], DNDarray):
-                    lkey[self.split] = key[self.split].larray
-                inds = (
-                    torch.tensor(
-                        lkey[self.split], dtype=torch.long, device=self.device.torch_device
-                    )
-                    if not isinstance(lkey[self.split], torch.Tensor)
-                    else lkey[self.split]
-                )
-
-                loc_inds = torch.where((inds >= chunk_start) & (inds < chunk_end))
-                if len(loc_inds[0]) != 0:
-                    # if there are no local indices on a process, then `arr` is empty
-                    inds = inds[loc_inds] - chunk_start
-                    lkey[self.split] = inds
-                    arr = self.__array[tuple(lkey)]
-            elif isinstance(key[self.split], slice):
-                # standard slicing along the split axis,
-                # adjust the slice start, stop, and step, then run it on the processes which have the requested data
-                key = list(key)
-                key_start = key[self.split].start if key[self.split].start is not None else 0
-                key_stop = (
-                    key[self.split].stop
-                    if key[self.split].stop is not None
-                    else self.gshape[self.split]
-                )
-                if key_stop < 0:
-                    key_stop = self.gshape[self.split] + key[self.split].stop
-                key_step = key[self.split].step
-                og_key_start = key_start
-                st_pr = torch.where(key_start < chunk_ends)[0]
-                st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size
-                sp_pr = torch.where(key_stop >= chunk_starts)[0]
-                sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0
-                actives = list(range(st_pr, sp_pr + 1))
-                if rank in actives:
-                    key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank]
-                    key_stop = ends[rank] if rank != actives[-1] else key_stop - chunk_starts[rank]
-                    if key_step is not None and rank > actives[0]:
-                        offset = (chunk_ends[rank - 1] - og_key_start) % key_step
-                        if key_step > 2 and offset > 0:
-                            key_start += key_step - offset
-                        elif key_step == 2 and offset > 0:
-                            key_start += (chunk_ends[rank - 1] - og_key_start) % key_step
-                    if isinstance(key_start, torch.Tensor):
-                        key_start = key_start.item()
-                    if isinstance(key_stop, torch.Tensor):
-                        key_stop = key_stop.item()
-                    key[self.split] = slice(key_start, key_stop, key_step)
-                    arr = self.__array[tuple(key)]
-
-            elif isinstance(key[self.split], int):
-                # if there is an integer in the key along the split axis, adjust it and then get `arr`
-                key = list(key)
-                key[self.split] = (
-                    key[self.split] + self.gshape[self.split]
-                    if key[self.split] < 0
-                    else key[self.split]
-                )
-                if key[self.split] in range(chunk_start, chunk_end):
-                    key[self.split] = key[self.split] - chunk_start
-                    arr = self.__array[tuple(key)]
-
-            if 0 in arr.shape:
-                # arr is empty
-                # gout is all 0s as is the shape
-                warnings.warn(
-                    "This process (rank: {}) is without data after slicing, "
-                    "running the .balance_() function is recommended".format(self.comm.rank),
-                    ResourceWarning,
-                )
-
-            return DNDarray(
-                arr.type(l_dtype),
-                gout_full if isinstance(gout_full, tuple) else tuple(gout_full),
-                self.dtype,
-                new_split,
-                self.device,
-                self.comm,
-                balanced=None,
-            )
+
+        # else: (DNDarray is distributed)
+        rank = self.comm.rank
+        ends = []
+        for pr in range(self.comm.size):
+            _, _, e = self.comm.chunk(self.shape, self.split, rank=pr)
+            ends.append(e[self.split].stop - e[self.split].start)
+        ends = torch.tensor(ends, device=self.device.torch_device)
+        chunk_ends = ends.cumsum(dim=0)
+        chunk_starts = torch.tensor([0] + chunk_ends.tolist(), device=self.device.torch_device)
+        chunk_start = chunk_starts[rank]
+        chunk_end = chunk_ends[rank]
+        arr = torch.tensor([], device=self.device.torch_device)
+        # all keys should be tuples here
+        # handle the dimensional reduction for integers
+        int_locs = [isinstance(it, int) for it in key]
+        ints = sum(int_locs)
+        if ints > 0:
+            ints_before_split = sum(int_locs[: self.split + 1])
+            new_split = self.split - ints_before_split
+            if new_split < 0:
+                new_split = 0
+        else:
+            new_split = self.split
+        if len(key) == 0:  # handle empty list
+            # this will return an array of shape (0, ...)
+            arr = self.__array[key]
+            gout_full = list(arr.shape)
+
+        """ At the end of the following if/elif/elif block the output array will be set.
+        each block handles the case where the element of the key along the split axis is a different type
+        and converts the key from global indices to local indices.
+        """
+        if isinstance(key[self.split], (list, torch.Tensor, DNDarray)):
+            # advanced indexing, elements in the split dimension are adjusted to the local indices
+            lkey = list(key)
+            if isinstance(key[self.split], DNDarray):
+                lkey[self.split] = key[self.split].larray
+            inds = (
+                torch.tensor(lkey[self.split], dtype=torch.long, device=self.device.torch_device)
+                if not isinstance(lkey[self.split], torch.Tensor)
+                else lkey[self.split]
+            )
+
+            loc_inds = torch.where((inds >= chunk_start) & (inds < chunk_end))
+            if len(loc_inds[0]) != 0:
+                # if there are no local indices on a process, then `arr` is empty
+                inds = inds[loc_inds] - chunk_start
+                lkey[self.split] = inds
+                arr = self.__array[tuple(lkey)]
+        elif isinstance(key[self.split], slice):
+            # standard slicing along the split axis,
+            # adjust the slice start, stop, and step, then run it on the processes which have the requested data
+            key = list(key)
+            key_start = key[self.split].start if key[self.split].start is not None else 0
+            key_stop = (
+                key[self.split].stop
+                if key[self.split].stop is not None
+                else self.gshape[self.split]
+            )
+            if key_stop < 0:
+                key_stop = self.gshape[self.split] + key[self.split].stop
+            key_step = key[self.split].step
+            og_key_start = key_start
+            st_pr = torch.where(key_start < chunk_ends)[0]
+            st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size
+            sp_pr = torch.where(key_stop >= chunk_starts)[0]
+            sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0
+            actives = list(range(st_pr, sp_pr + 1))
+            if rank in actives:
+                key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank]
+                key_stop = ends[rank] if rank != actives[-1] else key_stop - chunk_starts[rank]
+                if key_step is not None and rank > actives[0]:
+                    offset = (chunk_ends[rank - 1] - og_key_start) % key_step
+                    if key_step > 2 and offset > 0:
+                        key_start += key_step - offset
+                    elif key_step == 2 and offset > 0:
+                        key_start += (chunk_ends[rank - 1] - og_key_start) % key_step
+                if isinstance(key_start, torch.Tensor):
+                    key_start = key_start.item()
+                if isinstance(key_stop, torch.Tensor):
+                    key_stop = key_stop.item()
+                key[self.split] = slice(key_start, key_stop, key_step)
+                arr = self.__array[tuple(key)]
+
+        elif isinstance(key[self.split], int):
+            # if there is an integer in the key along the split axis, adjust it and then get `arr`
+            key = list(key)
+            key[self.split] = (
+                key[self.split] + self.gshape[self.split]
+                if key[self.split] < 0
+                else key[self.split]
+            )
+            if key[self.split] in range(chunk_start, chunk_end):
+                key[self.split] = key[self.split] - chunk_start
+                arr = self.__array[tuple(key)]
+
+        if 0 in arr.shape:
+            # arr is empty
+            # gout is all 0s as is the shape
+            warnings.warn(
+                "This process (rank: {}) is without data after slicing, "
+                "running the .balance_() function is recommended".format(self.comm.rank),
+                ResourceWarning,
+            )
+
+        return DNDarray(
+            arr.type(l_dtype),
+            gout_full if isinstance(gout_full, tuple) else tuple(gout_full),
+            self.dtype,
+            new_split,
+            self.device,
+            self.comm,
+            balanced=None,
+        )
 
     if torch.cuda.device_count() > 0:
 
         def gpu(self):
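For readers skimming the diff: the core of the fix is the new split-axis bookkeeping at the top of the distributed branch. Stated in isolation, it counts the integer indices that occur at or before the original split dimension and shifts the split left by that count, clamping at zero. A standalone sketch of that rule for illustration only (`new_split_axis` is a hypothetical helper, not part of the heat API):

```python
from typing import Tuple, Union

def new_split_axis(key: Tuple[Union[int, slice], ...], split: int) -> int:
    """Shift the split axis left by the number of integer indices at or before it."""
    int_locs = [isinstance(k, int) for k in key]
    if sum(int_locs) == 0:
        return split
    ints_before_split = sum(int_locs[: split + 1])
    return max(split - ints_before_split, 0)

# the same cases exercised by the regression test below (array split along axis 1)
assert new_split_axis((0, slice(None), slice(None)), split=1) == 0  # a[0]
assert new_split_axis((slice(None), 0, slice(None)), split=1) == 0  # a[:, 0, :]
assert new_split_axis((slice(None), slice(None), 0), split=1) == 1  # a[:, :, 0]
```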
7 changes: 7 additions & 0 deletions heat/core/tests/test_dndarray.py
@@ -952,6 +952,13 @@ def test_rshift(self):
             ht.array([True]) >> 2
 
     def test_setitem_getitem(self):
+        # tests for bug 730:
+        a = ht.ones((10, 25, 30), split=1)
+        if a.comm.size > 1:
+            self.assertEqual(a[0].split, 0)
+            self.assertEqual(a[:, 0, :].split, 0)
+            self.assertEqual(a[:, :, 0].split, 1)
+
         # set and get single value
         a = ht.zeros((13, 5), split=0)
         # set value on one node
