Bug/369 reduce op empty tensor #443

Closed · wants to merge 13 commits
heat/core/arithmetics.py (6 additions, 2 deletions)

@@ -590,7 +590,9 @@ def prod(x, axis=None, out=None, keepdim=None):
     ], axis=1)
     ht.tensor([ 2., 12.])
     """
-    return operations.__reduce_op(x, torch.prod, MPI.PROD, axis=axis, out=out, keepdim=keepdim)
+    return operations.__reduce_op(
+        x, torch.prod, MPI.PROD, axis=axis, out=out, neutral=torch.ones, keepdim=keepdim
+    )


 def sub(t1, t2):

@@ -679,4 +681,6 @@ def sum(x, axis=None, out=None, keepdim=None):
             [3.]]])
     """
     # TODO: make me more numpy API complete Issue #101
-    return operations.__reduce_op(x, torch.sum, MPI.SUM, axis=axis, out=out, keepdim=keepdim)
+    return operations.__reduce_op(
+        x, torch.sum, MPI.SUM, axis=axis, out=out, neutral=torch.zeros, keepdim=keepdim
+    )
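
For context on the `neutral` argument introduced above: it supplies the identity element of the reduction, so a process whose local shard is empty can feed in a placeholder that cannot change the global result (1 for products, 0 for sums). A standalone PyTorch sketch of that idea, independent of HeAT's internals, with all names illustrative:

import torch

# PyTorch's reductions over empty tensors already return these identities:
empty = torch.empty(0)
assert torch.prod(empty).item() == 1.0  # product over nothing is 1
assert torch.sum(empty).item() == 0.0   # sum over nothing is 0

# Simulated per-process shards, one of them empty (as after an uneven split):
shards = [torch.tensor([2.0, 3.0]), torch.empty(0), torch.tensor([4.0])]

# Substituting the identity for the empty shard leaves the result intact:
partials = [torch.prod(s) if s.numel() else torch.ones(()) for s in shards]
assert torch.prod(torch.stack(partials)).item() == 24.0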
heat/core/logical.py (6 additions, 2 deletions)

@@ -66,7 +66,9 @@ def all(x, axis=None, out=None, keepdim=None):
     def local_all(t, *args, **kwargs):
         return torch.all(t != 0, *args, **kwargs)

-    return operations.__reduce_op(x, local_all, MPI.LAND, axis=axis, out=out, keepdim=keepdim)
+    return operations.__reduce_op(
+        x, local_all, MPI.LAND, axis=axis, out=out, neutral=torch.ones, keepdim=keepdim
+    )


 def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False):

@@ -206,4 +208,6 @@ def any(x, axis=None, out=None, keepdim=False):
     def local_any(t, *args, **kwargs):
         return torch.any(t != 0, *args, **kwargs)

-    return operations.__reduce_op(x, local_any, MPI.LOR, axis=axis, out=out, keepdim=keepdim)
+    return operations.__reduce_op(
+        x, local_any, MPI.LOR, axis=axis, out=out, neutral=torch.ones, keepdim=keepdim
+    )
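
The same reasoning applies to the logical reductions (again a sketch, not HeAT code): the identity element of AND is True, hence `torch.ones` is the natural choice for `all`; the identity element of OR is False, i.e. an all-zeros placeholder is the one that cannot flip an `any` result:

import torch

# Reductions over empty boolean tensors yield the respective identities:
empty = torch.empty(0, dtype=torch.bool)
assert torch.all(empty).item()       # AND over nothing is True
assert not torch.any(empty).item()   # OR over nothing is False

# A neutral placeholder must not change the outcome; for `any`, a zeros
# placeholder keeps an all-False input False:
shards = [torch.zeros(2, dtype=torch.bool), torch.empty(0, dtype=torch.bool)]
partials = [torch.any(s) if s.numel() else torch.zeros((), dtype=torch.bool)
            for s in shards]
assert not torch.any(torch.stack(partials)).item()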
heat/core/operations.py (19 additions, 4 deletions)

@@ -385,18 +385,33 @@ def __reduce_op(x, partial_op, reduction_op, **kwargs):

     # no further checking needed, sanitize axis will raise the proper exceptions
     axis = stride_tricks.sanitize_axis(x.shape, kwargs.get("axis"))
-    split = x.split
     keepdim = kwargs.get("keepdim")
+    split = x.split
+    no_data = False
+    # check if local tensor is empty
+    if split is not None:
+        no_data = x.lshape[split] == 0
+    if no_data:
+        # local tensor contains no data, replace with neutral element
+        neutral_element = kwargs.get("neutral")
+        if neutral_element is None:
+            # warnings.warn(
+            #     "Local tensor has no data and argument 'neutral' is None. Setting 'neutral' to torch.ones"
+            # )
+            neutral_element = torch.ones
+        neutral_shape = x.lshape[:split] + (1,) + x.lshape[split + 1 :]
+        neutral_partial = neutral_element(neutral_shape)
+
+    partial = x._DNDarray__array if not no_data else neutral_partial

     if axis is None:
-        partial = partial_op(x._DNDarray__array).reshape(-1)
+        partial = partial_op(partial).reshape(-1)
         output_shape = (1,)
     else:
         if isinstance(axis, int):
             axis = (axis,)

-        if isinstance(axis, tuple):
-            partial = x._DNDarray__array
+        if isinstance(axis, tuple):  # TODO: do we need this check at all?? axis has been sanitized
             output_shape = x.gshape
             for dim in axis:
                 partial = partial_op(partial, dim=dim, keepdim=True)
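
The hunk above is the heart of the change: an empty local shard is swapped for a neutral tensor of matching shape before the local `partial_op` runs, so the later MPI combine always sees a well-formed partial result. A self-contained single-process sketch of that idea (`reduce_shards`, `partial_op`, and `combine_op` are illustrative names, not HeAT's API):

import torch

def reduce_shards(shards, partial_op, combine_op, neutral=torch.ones):
    # Each shard stands in for one process's local tensor, split along dim 0.
    partials = []
    for shard in shards:
        if shard.shape[0] == 0:  # no local data on this "process"
            # Mirrors neutral_shape = lshape[:split] + (1,) + lshape[split + 1:]
            shard = neutral((1,) + shard.shape[1:])
        partials.append(partial_op(shard, dim=0, keepdim=True))
    # Stand-in for the MPI reduction across processes:
    return combine_op(torch.cat(partials), dim=0)

# Shards of a (3, 2) array split along axis 0, the middle one empty:
shards = [
    torch.tensor([[2.0, 3.0]]),
    torch.empty(0, 2),
    torch.tensor([[4.0, 5.0], [1.0, 1.0]]),
]
print(reduce_shards(shards, torch.prod, torch.prod))  # tensor([ 8., 15.])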
heat/core/statistics.py (9 additions, 4 deletions)

@@ -94,8 +94,9 @@ def local_argmax(*args, **kwargs):
         raise TypeError("axis must be None or int, but was {}".format(type(axis)))

     # perform the global reduction
+    # TODO: define neutral_element for local_argmax
     reduced_result = operations.__reduce_op(
-        x, local_argmax, MPI_ARGMAX, axis=axis, out=None, **kwargs
+        x, local_argmax, MPI_ARGMAX, axis=axis, out=None, **kwargs  # neutral=torch.ones,
     )

     # correct the tensor

@@ -195,7 +196,7 @@ def local_argmin(*args, **kwargs):

     # perform the global reduction
     reduced_result = operations.__reduce_op(
-        x, local_argmin, MPI_ARGMIN, axis=axis, out=None, **kwargs
+        x, local_argmin, MPI_ARGMIN, axis=axis, out=None, neutral=torch.ones, **kwargs
     )

     # correct the tensor

@@ -508,7 +509,9 @@ def local_max(*args, **kwargs):
             result = result[0]
         return result

-    return operations.__reduce_op(x, local_max, MPI.MAX, axis=axis, out=out, keepdim=keepdim)
+    return operations.__reduce_op(
+        x, local_max, MPI.MAX, axis=axis, out=out, neutral=torch.ones, keepdim=keepdim
+    )


 def maximum(x1, x2, out=None):

@@ -980,7 +983,9 @@ def local_min(*args, **kwargs):
             result = result[0]
         return result

-    return operations.__reduce_op(x, local_min, MPI.MIN, axis=axis, out=out, keepdim=keepdim)
+    return operations.__reduce_op(
+        x, local_min, MPI.MIN, axis=axis, out=out, neutral=torch.ones, keepdim=keepdim
+    )


 def minimum(x1, x2, out=None):
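
A final note on the min/max hunks, as a general sketch rather than a statement about this PR: PyTorch raises on reducing an empty tensor with min/max instead of returning an identity, which is exactly the failure a placeholder must avoid; the mathematical identities here are -inf for max and +inf for min:

import torch

# min/max over an empty tensor raise rather than return an identity:
try:
    torch.max(torch.empty(0))
except RuntimeError as err:
    print("empty max:", err)

# max(x, -inf) == x and min(x, +inf) == x, so +/-inf placeholders stay
# neutral even when every real value is negative:
shards = [torch.tensor([-3.0, -7.0]), torch.empty(0)]
partials = [torch.min(s) if s.numel() else torch.full((), float("inf"))
            for s in shards]
assert torch.min(torch.stack(partials)).item() == -7.0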