Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add out and where args for ht.div #945

Merged
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@
## Feature Additions

### Arithmetics
- - [#887](https://github.com/helmholtz-analytics/heat/pull/887) Binary operations now support operands of equal shapes, equal `split` axes, but different distribution maps.

## Feature additions
- [#887](https://github.com/helmholtz-analytics/heat/pull/887) Binary operations now support operands of equal shapes, equal `split` axes, but different distribution maps.
- [#945](https://github.com/helmholtz-analytics/heat/pull/945) `div` now supports `out` and `where` kwargs
### Communication
- [#868](https://github.com/helmholtz-analytics/heat/pull/868) New `MPICommunication` method `Split`
### DNDarray
- [#856](https://github.com/helmholtz-analytics/heat/pull/856) New `DNDarray` method `__torch_proxy__`
- [#885](https://github.com/helmholtz-analytics/heat/pull/885) New `DNDarray` method `conj`

# Feature additions
### Linear Algebra
- [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot`
Expand Down
43 changes: 29 additions & 14 deletions heat/core/_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __binary_op(
t1: Union[DNDarray, int, float],
t2: Union[DNDarray, int, float],
out: Optional[DNDarray] = None,
where: Optional[DNDarray] = None,
fn_kwargs: Optional[Dict] = {},
) -> DNDarray:
"""
Expand All @@ -43,6 +44,8 @@ def __binary_op(
The second operand involved in the operation,
out: DNDarray, optional
Output buffer in which the result is placed
where: DNDarray, optional
Condition of interest; where True, yields the result of the operation, else yields the original value in out (uninitialized when out=None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use numpy's docs for where, I think they are a bit clearer. But we must expand on them a bit, e.g. is where supposed/expected to be distributed, and how.

fn_kwargs: Dict, optional
keyword arguments used for the given operation
Default: {} (empty dictionary)
Expand Down Expand Up @@ -101,6 +104,8 @@ def __binary_op(

# Make inputs have the same dimensionality
output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape)
if where is not None:
output_shape = stride_tricks.broadcast_shape(where.shape, output_shape)
# Broadcasting allows additional empty dimensions on the left side
# TODO simplify this once newaxis-indexing is supported to get rid of the loops
while len(t1.shape) < len(output_shape):
Expand Down Expand Up @@ -163,23 +168,33 @@ def __get_out_params(target, other=None, map=None):
if out is not None:
sanitation.sanitize_out(out, output_shape, output_split, output_device, output_comm)
t1, t2 = sanitation.sanitize_distribution(t1, t2, target=out)
out.larray[:] = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs

result = operation(t1.larray.to(promoted_type), t2.larray.to(promoted_type), **fn_kwargs)

if out is None and where is None:
return DNDarray(
result,
output_shape,
types.heat_type_of(result),
output_split,
device=output_device,
comm=output_comm,
balanced=output_balanced,
)
return out
# print(t1.lshape, t2.lshape)

result = operation(t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs)
if where is not None:
ClaudiaComito marked this conversation as resolved.
Show resolved Hide resolved
if out is None:
out = factories.empty(
output_shape,
dtype=promoted_type,
split=output_split,
device=output_device,
comm=output_comm,
)
result = torch.where(where.larray, result, out.larray)
ClaudiaComito marked this conversation as resolved.
Show resolved Hide resolved

return DNDarray(
result,
output_shape,
types.heat_type_of(result),
output_split,
device=output_device,
comm=output_comm,
balanced=output_balanced,
)
out.larray.copy_(result)
return out


def __cum_op(
Expand Down
13 changes: 11 additions & 2 deletions heat/core/arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,12 @@ def diff(
return ret


def div(t1: Union[DNDarray, float], t2: Union[DNDarray, float]) -> DNDarray:
def div(
t1: Union[DNDarray, float],
t2: Union[DNDarray, float],
out: Optional[DNDarray] = None,
where: Optional[DNDarray] = None,
) -> DNDarray:
"""
Element-wise true division of values of operand ``t1`` by values of operands ``t2`` (i.e ``t1/t2``).
Operation is not commutative.
Expand All @@ -438,6 +443,10 @@ def div(t1: Union[DNDarray, float], t2: Union[DNDarray, float]) -> DNDarray:
The first operand whose values are divided
t2: DNDarray or scalar
The second operand by whose values is divided
out: DNDarray, optional
The output array. It must have a shape that the inputs broadcast to, and a matching split axis
where: DNDarray, optional
Condition of interest; where True, yields the divided value, else yields the original value in out (uninitialized when out=None)

Example
---------
Expand All @@ -453,7 +462,7 @@ def div(t1: Union[DNDarray, float], t2: Union[DNDarray, float]) -> DNDarray:
DNDarray([[2.0000, 1.0000],
[0.6667, 0.5000]], dtype=ht.float32, device=cpu:0, split=None)
"""
return _operations.__binary_op(torch.true_divide, t1, t2)
return _operations.__binary_op(torch.true_divide, t1, t2, out, where)


DNDarray.__truediv__ = lambda self, other: div(self, other)
Expand Down
30 changes: 30 additions & 0 deletions heat/core/tests/test_arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,12 +362,42 @@ def test_div(self):
self.assertTrue(ht.equal(ht.div(self.a_tensor, self.an_int_scalar), result))
self.assertTrue(ht.equal(ht.div(self.a_split_tensor, self.a_tensor), commutated_result))

a = out = ht.empty((2, 2))
ht.div(self.a_tensor, self.a_scalar, out=out)
self.assertTrue(ht.equal(out, result))
self.assertIs(a, out)
b = ht.array([[1.0, 2.0], [3.0, 4.0]])
ht.div(b, self.another_tensor, out=b)
self.assertTrue(ht.equal(b, result))
out = ht.empty((2, 2), split=self.a_split_tensor.split)
ht.div(self.a_split_tensor, self.a_tensor, out=out)
self.assertTrue(ht.equal(out, commutated_result))
self.assertEqual(self.a_split_tensor.split, out.split)

result_where = ht.array([[1.0, 2.0], [1.5, 2.0]])
self.assertTrue(
ht.equal(
ht.div(self.a_tensor, self.a_scalar, where=self.a_tensor > 2)[1, :],
result_where[1, :],
)
)

a = self.a_tensor.copy()
ht.div(a, self.a_scalar, out=a, where=a > 2)
self.assertTrue(ht.equal(a, result_where))
result_where_broadcasted = ht.array([[1.0, 1.0], [3.0, 2.0]])
a = self.a_tensor.copy()
ht.div(a, self.a_scalar, out=a, where=ht.array([False, True]))
self.assertTrue(ht.equal(a, result_where_broadcasted))

ClaudiaComito marked this conversation as resolved.
Show resolved Hide resolved
ClaudiaComito marked this conversation as resolved.
Show resolved Hide resolved
with self.assertRaises(ValueError):
ht.div(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.div(self.a_tensor, self.erroneous_type)
with self.assertRaises(TypeError):
ht.div("T", "s")
with self.assertRaises(ValueError):
ht.div(self.a_split_tensor, self.a_tensor, out=ht.empty((2, 2), split=None))

def test_fmod(self):
result = ht.array([[1.0, 0.0], [1.0, 0.0]])
Expand Down