Add out and where args for ht.div #945

Merged
34 changes: 20 additions & 14 deletions heat/core/_operations.py
@@ -26,6 +26,7 @@ def __binary_op(
t1: Union[DNDarray, int, float],
t2: Union[DNDarray, int, float],
out: Optional[DNDarray] = None,
where: Optional[DNDarray] = None,
fn_kwargs: Optional[Dict] = {},
) -> DNDarray:
"""
@@ -43,6 +44,8 @@ def __binary_op(
The second operand involved in the operation,
out: DNDarray, optional
Output buffer in which the result is placed
where: DNDarray, optional
Condition of interest; where True, yield the result of the operation, otherwise yield the original value in out (uninitialized when out=None)

Contributor comment:

We can use numpy's docs for where, I think they are a bit clearer. But we must expand on them a bit, e.g. is where supposed/expected to be distributed, and how.
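
For reference, a minimal numpy sketch of the where semantics the comment points to (values are illustrative): numpy's ufuncs only write positions where the condition is True and leave the rest of out untouched.

    import numpy as np

    a = np.array([[1.0, 2.0], [3.0, 4.0]])
    out = np.zeros_like(a)
    # Division is only performed (and written) where a > 2;
    # the remaining entries keep their original value from `out` (here 0.0).
    np.divide(a, 2.0, out=out, where=a > 2)
    print(out)  # [[0.  0. ] [1.5 2. ]]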

fn_kwargs: Dict, optional
keyword arguments used for the given operation
Default: {} (empty dictionary)
@@ -101,6 +104,8 @@ def __binary_op(

# Make inputs have the same dimensionality
output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape)
if where is not None:
output_shape = stride_tricks.broadcast_shape(where.shape, output_shape)
# Broadcasting allows additional empty dimensions on the left side
# TODO simplify this once newaxis-indexing is supported to get rid of the loops
while len(t1.shape) < len(output_shape):
@@ -163,23 +168,24 @@ def __get_out_params(target, other=None, map=None):
if out is not None:
sanitation.sanitize_out(out, output_shape, output_split, output_device, output_comm)
t1, t2 = sanitation.sanitize_distribution(t1, t2, target=out)
out.larray[:] = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
else:
out_tensor = torch.empty(output_shape, dtype=promoted_type)

Contributor comment from @ClaudiaComito, Apr 4, 2022:

Two comments here:

  • output_shape is the global (memory-distributed) shape of the output DNDarray, so here you are initializing a potentially huge torch tensor. In this case you should call factories.empty(output_shape, dtype=..., split=..., device=...), which takes care of only initializing slices of the global array on each process. (I think this is also why the tests fail, by the way.)

  • If I understand the numpy docs correctly, this empty out DNDarray only needs to be initialized if where is not None.
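
A standalone sketch of the distributed-allocation point above, using Heat's public empty factory (shape, split axis, and process count are illustrative, not taken from this PR):

    import heat as ht

    # With a split axis, each MPI process allocates only its local slice of the
    # global array instead of one full-size torch tensor per rank.
    out = ht.empty((8, 4), dtype=ht.float32, split=0)
    print(out.shape)   # global shape: (8, 4)
    print(out.lshape)  # local shape on this process, e.g. (2, 4) when run on 4 processes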

out = DNDarray(
out_tensor,
output_shape,
types.heat_type_of(out_tensor),
output_split,
device=output_device,
comm=output_comm,
balanced=output_balanced,
)
return out
# print(t1.lshape, t2.lshape)

result = operation(t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs)
result = operation(t1.larray.to(promoted_type), t2.larray.to(promoted_type), **fn_kwargs)
if where is not None:
result = torch.where(where.larray, result, out.larray)

return DNDarray(
result,
output_shape,
types.heat_type_of(result),
output_split,
device=output_device,
comm=output_comm,
balanced=output_balanced,
)
out.larray.copy_(result)
return out


def __cum_op(
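A minimal torch-level sketch of the updated code path in __binary_op above (local tensors only, illustrative values; in the actual implementation these are the .larray buffers of distributed DNDarrays):

    import torch

    t1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    out = torch.zeros(2, 2)
    cond = t1 > 2

    # Compute the full result, keep it only where the condition holds,
    # then write back into the preallocated output buffer.
    result = torch.true_divide(t1, 2.0)
    result = torch.where(cond, result, out)
    out.copy_(result)
    print(out)  # tensor([[0.0000, 0.0000], [1.5000, 2.0000]])
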
13 changes: 11 additions & 2 deletions heat/core/arithmetics.py
@@ -427,7 +427,12 @@ def diff(
return ret


def div(t1: Union[DNDarray, float], t2: Union[DNDarray, float]) -> DNDarray:
def div(
t1: Union[DNDarray, float],
t2: Union[DNDarray, float],
out: Optional[DNDarray] = None,
where: Optional[DNDarray] = None,
) -> DNDarray:
"""
Element-wise true division of values of operand ``t1`` by values of operand ``t2`` (i.e. ``t1/t2``).
Operation is not commutative.
@@ -438,6 +443,10 @@ def div(t1: Union[DNDarray, float], t2: Union[DNDarray, float]) -> DNDarray:
The first operand whose values are divided
t2: DNDarray or scalar
The second operand by whose values is divided
out: DNDarray, optional
The output array. It must have a shape that the inputs broadcast to and a matching split axis
where: DNDarray, optional
Condition of interest; where True, yield the divided value, otherwise yield the original value in out (uninitialized when out=None)

Example
---------
@@ -453,7 +462,7 @@ def div(t1: Union[DNDarray, float], t2: Union[DNDarray, float]) -> DNDarray:
DNDarray([[2.0000, 1.0000],
[0.6667, 0.5000]], dtype=ht.float32, device=cpu:0, split=None)
"""
return _operations.__binary_op(torch.true_divide, t1, t2)
return _operations.__binary_op(torch.true_divide, t1, t2, out, where)


DNDarray.__truediv__ = lambda self, other: div(self, other)
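A short usage sketch of the new out/where arguments to ht.div, mirroring the expected values in the tests below (output formatting abridged):

    import heat as ht

    a = ht.array([[1.0, 2.0], [3.0, 4.0]])
    buf = a.copy()

    # Divide only where the condition is True; other entries keep their value from `out`.
    ht.div(a, 2.0, out=buf, where=a > 2)
    print(buf)  # DNDarray([[1.0000, 2.0000], [1.5000, 2.0000]], dtype=ht.float32, ...)
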
26 changes: 26 additions & 0 deletions heat/core/tests/test_arithmetics.py
@@ -362,6 +362,32 @@ def test_div(self):
self.assertTrue(ht.equal(ht.div(self.a_tensor, self.an_int_scalar), result))
self.assertTrue(ht.equal(ht.div(self.a_split_tensor, self.a_tensor), commutated_result))

a = out = ht.empty((2, 2))
ht.div(self.a_tensor, self.a_scalar, out=out)
self.assertTrue(ht.equal(out, result))
self.assertIs(a, out)
b = ht.array([[1.0, 2.0], [3.0, 4.0]])
ht.div(b, self.another_tensor, out=b)
self.assertTrue(ht.equal(b, result))

result_where = ht.array([[1.0, 2.0], [1.5, 2.0]])
self.assertTrue(
ht.equal(
ht.div(self.a_tensor, self.a_scalar, where=self.a_tensor > 2)[1, :],
result_where[1, :],
)
)
a = self.a_tensor.copy()
ht.div(a, self.a_scalar, out=a, where=a > 2)
self.assertTrue(ht.equal(a, result_where))
result_where_broadcasted = ht.array([[1.0, 1.0], [3.0, 2.0]])
a = self.a_tensor.copy()
ht.div(a, self.a_scalar, out=a, where=ht.array([False, True]))
self.assertTrue(ht.equal(a, result_where_broadcasted))
a = ht.array([[1.0, 2.0], [3.0, 4.0]], split=1)
ht.div(a, self.another_tensor, out=a, where=ht.array([[False], [True]], split=0))
self.assertTrue(ht.equal(a, result_where))

with self.assertRaises(ValueError):
ht.div(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):