Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add out and where args for ht.div #945

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions heat/core/arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Optional, Union, Tuple

from . import factories
from . import indexing
from . import manipulations
from . import _operations
from . import sanitation
Expand Down Expand Up @@ -427,7 +428,12 @@ def diff(
return ret


def div(t1: Union[DNDarray, float], t2: Union[DNDarray, float]) -> DNDarray:
def div(
t1: Union[DNDarray, float],
t2: Union[DNDarray, float],
out: Optional[DNDarray] = None,
where: Optional[DNDarray] = None,
) -> DNDarray:
"""
Element-wise true division of values of operand ``t1`` by values of operands ``t2`` (i.e ``t1/t2``).
Operation is not commutative.
Expand All @@ -438,6 +444,10 @@ def div(t1: Union[DNDarray, float], t2: Union[DNDarray, float]) -> DNDarray:
The first operand whose values are divided
t2: DNDarray or scalar
The second operand by whose values is divided
out: DNDarray, optional
The output array. It must have a shape that the inputs broadcast to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shape and split dimension

where: DNDarray, optional
Condition of interest, where true yield divided value else yield original value in t1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should follow numpy.divide, so ht.divide should actually yield out where where is False (and uninitialized values when out=None)


Example
---------
Expand All @@ -453,7 +463,10 @@ def div(t1: Union[DNDarray, float], t2: Union[DNDarray, float]) -> DNDarray:
DNDarray([[2.0000, 1.0000],
[0.6667, 0.5000]], dtype=ht.float32, device=cpu:0, split=None)
"""
return _operations.__binary_op(torch.true_divide, t1, t2)
if where is not None:
t2 = indexing.where(where, t2, 1)

return _operations.__binary_op(torch.true_divide, t1, t2, out)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should return out instead of t1 where where is False. And in fact this applies to all binary operations, as I admittedly had not realized when I created this issue.

So the way to go here would be to modify _operations.__binary_op to accomodate the where kwarg once and for all. Do you need help with that?



DNDarray.__truediv__ = lambda self, other: div(self, other)
Expand Down
18 changes: 18 additions & 0 deletions heat/core/tests/test_arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,24 @@ def test_div(self):
self.assertTrue(ht.equal(ht.div(self.a_tensor, self.an_int_scalar), result))
self.assertTrue(ht.equal(ht.div(self.a_split_tensor, self.a_tensor), commutated_result))

a = out = ht.empty((2, 2))
ht.div(self.a_tensor, self.a_scalar, out=out)
self.assertTrue(ht.equal(out, result))
self.assertIs(a, out)
b = ht.array([[1.0, 2.0], [3.0, 4.0]])
ht.div(b, self.another_tensor, out=b)
self.assertTrue(ht.equal(b, result))

result_where = ht.array([[1.0, 2.0], [1.5, 2.0]])
self.assertTrue(
ht.equal(ht.div(self.a_tensor, self.a_scalar, where=self.a_tensor > 2), result_where)
)
ht.div(self.a_tensor, self.a_scalar)

a = self.a_tensor.copy()
ht.div(a, self.a_scalar, out=a, where=a > 2)
self.assertTrue(ht.equal(a, result_where))

ClaudiaComito marked this conversation as resolved.
Show resolved Hide resolved
ClaudiaComito marked this conversation as resolved.
Show resolved Hide resolved
with self.assertRaises(ValueError):
ht.div(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
Expand Down