Add out and where args for ht.div #945

Merged

Changes from 13 commits
7 changes: 2 additions & 5 deletions CHANGELOG.md
@@ -17,16 +17,13 @@
 ## Feature Additions
 
 ### Arithmetics
-- [#887](https://github.com/helmholtz-analytics/heat/pull/887) Binary operations now support operands of equal shapes, equal `split` axes, but different distribution maps.
-
-## Feature additions
+- [#887](https://github.com/helmholtz-analytics/heat/pull/887) Binary operations now support operands of equal shapes, equal `split` axes, but different distribution maps.
+- [#945](https://github.com/helmholtz-analytics/heat/pull/945) `div` now supports `out` and `where` kwargs
 ### Communication
 - [#868](https://github.com/helmholtz-analytics/heat/pull/868) New `MPICommunication` method `Split`
 ### DNDarray
 - [#856](https://github.com/helmholtz-analytics/heat/pull/856) New `DNDarray` method `__torch_proxy__`
 - [#885](https://github.com/helmholtz-analytics/heat/pull/885) New `DNDarray` method `conj`
-
-# Feature additions
 ### Linear Algebra
 - [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
 - [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot`
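In plain usage terms, the new CHANGELOG entry amounts to the call pattern below; a minimal sketch of the feature this PR adds (array values and shapes are illustrative, not taken from the diff):

```python
import heat as ht

a = ht.array([[1.0, 2.0], [3.0, 4.0]])
b = ht.array([[2.0, 2.0], [2.0, 2.0]])

# Write the quotient into a preallocated buffer instead of a fresh array.
out = ht.empty((2, 2))
ht.div(a, b, out=out)

# Only positions where the mask is True are overwritten; `out` keeps its
# previous values elsewhere.
mask = a > 2
ht.div(a, b, out=out, where=mask)
```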
58 changes: 41 additions & 17 deletions heat/core/_operations.py
@@ -26,6 +26,7 @@ def __binary_op(
     t1: Union[DNDarray, int, float],
     t2: Union[DNDarray, int, float],
     out: Optional[DNDarray] = None,
+    where: Optional[DNDarray] = None,
     fn_kwargs: Optional[Dict] = {},
 ) -> DNDarray:
     """
@@ -38,11 +39,17 @@
         The operation to be performed. Function that performs operation elements-wise on the involved tensors,
         e.g. add values from other to self
     t1: DNDarray or scalar
-        The first operand involved in the operation,
+        The first operand involved in the operation.
     t2: DNDarray or scalar
-        The second operand involved in the operation,
+        The second operand involved in the operation.
     out: DNDarray, optional
-        Output buffer in which the result is placed
+        Output buffer in which the result is placed. If not provided, a freshly allocated array is returned.
+    where: DNDarray, optional
+        Condition to broadcast over the inputs. At locations where the condition is True, the `out` array
+        will be set to the result of the operation. Elsewhere, the `out` array will retain its original
+        value. If an uninitialized `out` array is created via the default `out=None`, locations within
+        it where the condition is False will remain uninitialized. If distributed, the split axis (after
+        broadcasting if required) must match that of the `out` array.
     fn_kwargs: Dict, optional
         keyword arguments used for the given operation
         Default: {} (empty dictionary)
@@ -101,12 +108,17 @@
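The `where` semantics documented above follow the behaviour of NumPy's `where=` keyword on ufuncs; for comparison only, a NumPy sketch of the "retains its original value" and "remains uninitialized" caveats (the NumPy calls are an illustration, not part of this diff):

```python
import numpy as np

a = np.array([4.0, 9.0, 16.0])
b = np.array([2.0, 3.0, 4.0])

out = np.full(3, -1.0)
np.divide(a, b, out=out, where=np.array([True, False, True]))
# out is now [2., -1., 4.]: the masked-out slot kept its original -1.

res = np.divide(a, b, where=np.array([True, False, True]))
# res[1] is whatever happened to be in the freshly allocated buffer --
# uninitialized, exactly the caveat spelled out for `out=None` above.
```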

 # Make inputs have the same dimensionality
 output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape)
+if where is not None:
+    output_shape = stride_tricks.broadcast_shape(where.shape, output_shape)
 # Broadcasting allows additional empty dimensions on the left side
 # TODO simplify this once newaxis-indexing is supported to get rid of the loops
 while len(t1.shape) < len(output_shape):
     t1 = t1.expand_dims(axis=0)
 while len(t2.shape) < len(output_shape):
     t2 = t2.expand_dims(axis=0)
+if where is not None:
+    while len(where.shape) < len(output_shape):
+        where = where.expand_dims(axis=0)
 # t1 = t1[tuple([None] * (len(output_shape) - t1.ndim))]
 # t2 = t2[tuple([None] * (len(output_shape) - t2.ndim))]
 # print(t1.lshape, t2.lshape)
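For intuition, the loop above left-pads missing leading axes one at a time until the ranks match, and with this PR the same padding is applied to the mask. A rough standalone sketch using only the helpers visible in this hunk (`broadcast_shape`, `expand_dims`), with illustrative shapes:

```python
import heat as ht
from heat.core import stride_tricks

t1 = ht.ones((2, 2))
where = ht.array([False, True])  # rank 1, needs left-padding to rank 2

output_shape = stride_tricks.broadcast_shape(t1.shape, where.shape)  # (2, 2)
while len(where.shape) < len(output_shape):
    where = where.expand_dims(axis=0)  # shape (2,) -> (1, 2)
# `where` now broadcasts row-wise against the (2, 2) result.
```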
@@ -163,23 +175,35 @@ def __get_out_params(target, other=None, map=None):
     if out is not None:
         sanitation.sanitize_out(out, output_shape, output_split, output_device, output_comm)
         t1, t2 = sanitation.sanitize_distribution(t1, t2, target=out)
-        out.larray[:] = operation(
-            t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
-        )
-        return out
-    # print(t1.lshape, t2.lshape)
-
-    result = operation(t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs)
-    return DNDarray(
-        result,
-        output_shape,
-        types.heat_type_of(result),
-        output_split,
-        device=output_device,
-        comm=output_comm,
-        balanced=output_balanced,
-    )
+
+    result = operation(t1.larray.to(promoted_type), t2.larray.to(promoted_type), **fn_kwargs)
+
+    if out is None and where is None:
+        return DNDarray(
+            result,
+            output_shape,
+            types.heat_type_of(result),
+            output_split,
+            device=output_device,
+            comm=output_comm,
+            balanced=output_balanced,
+        )
+
+    if where is not None:
+        if out is None:
+            out = factories.empty(
+                output_shape,
+                dtype=promoted_type,
+                split=output_split,
+                device=output_device,
+                comm=output_comm,
+            )
+        if where.split != out.split:
+            where = sanitation.sanitize_distribution(where, target=out)
+        result = torch.where(where.larray, result, out.larray)
+
+    out.larray.copy_(result)
+    return out
 
 
 def __cum_op(
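Setting the distribution bookkeeping aside, the control flow added in this hunk boils down to the local PyTorch sketch below; `binary_op_sketch` is a hypothetical toy name for illustration, not the library function:

```python
import torch

def binary_op_sketch(operation, t1, t2, out=None, where=None):
    # Raw elementwise result -- stands in for operation on the local tensors.
    result = operation(t1, t2)
    # Fast path: nothing to merge or write back.
    if out is None and where is None:
        return result
    if where is not None:
        if out is None:
            # Stands in for factories.empty(...): unmasked slots stay arbitrary.
            out = torch.empty_like(result)
        # Keep the new result where the mask is True, `out`'s old values elsewhere.
        result = torch.where(where, result, out)
    # Write through to the output buffer and return it.
    out.copy_(result)
    return out

# Example: divide only where the mask is True, keep -1.0 elsewhere.
buf = torch.full((3,), -1.0)
binary_op_sketch(torch.true_divide, torch.tensor([4.0, 9.0, 16.0]),
                 torch.tensor([2.0, 3.0, 4.0]),
                 out=buf, where=torch.tensor([True, False, True]))
# buf is now tensor([ 2., -1.,  4.])
```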
22 changes: 18 additions & 4 deletions heat/core/arithmetics.py
@@ -427,17 +427,31 @@ def diff(
     return ret
 
 
-def div(t1: Union[DNDarray, float], t2: Union[DNDarray, float]) -> DNDarray:
+def div(
+    t1: Union[DNDarray, float],
+    t2: Union[DNDarray, float],
+    out: Optional[DNDarray] = None,
+    where: Optional[DNDarray] = None,
+) -> DNDarray:
     """
     Element-wise true division of values of operand ``t1`` by values of operands ``t2`` (i.e ``t1/t2``).
     Operation is not commutative.
 
     Parameters
     ----------
     t1: DNDarray or scalar
-        The first operand whose values are divided
+        The first operand whose values are divided.
     t2: DNDarray or scalar
-        The second operand by whose values is divided
+        The second operand by whose values is divided.
+    out: DNDarray, optional
+        The output array. It must have a shape that the inputs broadcast to and matching split axis.
+        If not provided, a freshly allocated array is returned.
+    where: DNDarray, optional
+        Condition to broadcast over the inputs. At locations where the condition is True, the `out` array
+        will be set to the divided value. Elsewhere, the `out` array will retain its original value. If
+        an uninitialized `out` array is created via the default `out=None`, locations within it where the
+        condition is False will remain uninitialized. If distributed, the split axis (after broadcasting
+        if required) must match that of the `out` array.
 
     Example
     ---------
@@ -453,7 +467,7 @@ def div(t1: Union[DNDarray, float], t2: Union[DNDarray, float]) -> DNDarray:
     DNDarray([[2.0000, 1.0000],
               [0.6667, 0.5000]], dtype=ht.float32, device=cpu:0, split=None)
     """
-    return _operations.__binary_op(torch.true_divide, t1, t2)
+    return _operations.__binary_op(torch.true_divide, t1, t2, out, where)
 
 
 DNDarray.__truediv__ = lambda self, other: div(self, other)
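A short usage sketch of the new keywords, built on the docstring's own `T1`/`T2` example (the initial values of `buf` are illustrative):

```python
import heat as ht

T1 = ht.float32([[1, 2], [3, 4]])
T2 = ht.float32([[2, 2], [2, 2]])

buf = ht.float32([[-1, -1], [-1, -1]])
ht.div(T1, T2, out=buf, where=T1 > 2)
# buf is now DNDarray([[-1.0, -1.0],
#                      [ 1.5,  2.0]]): only the second row satisfies T1 > 2,
# so only those positions were overwritten with the quotient.
```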
51 changes: 51 additions & 0 deletions heat/core/tests/test_arithmetics.py
@@ -362,12 +362,63 @@ def test_div(self):
         self.assertTrue(ht.equal(ht.div(self.a_tensor, self.an_int_scalar), result))
         self.assertTrue(ht.equal(ht.div(self.a_split_tensor, self.a_tensor), commutated_result))
 
+        a = out = ht.empty((2, 2))
+        ht.div(self.a_tensor, self.a_scalar, out=out)
+        self.assertTrue(ht.equal(out, result))
+        self.assertIs(a, out)
+        b = ht.array([[1.0, 2.0], [3.0, 4.0]])
+        ht.div(b, self.another_tensor, out=b)
+        self.assertTrue(ht.equal(b, result))
+        out = ht.empty((2, 2), split=self.a_split_tensor.split)
+        ht.div(self.a_split_tensor, self.a_tensor, out=out)
+        self.assertTrue(ht.equal(out, commutated_result))
+        self.assertEqual(self.a_split_tensor.split, out.split)
+
+        result_where = ht.array([[1.0, 2.0], [1.5, 2.0]])
+        self.assertTrue(
+            ht.equal(
+                ht.div(self.a_tensor, self.a_scalar, where=self.a_tensor > 2)[1, :],
+                result_where[1, :],
+            )
+        )
+
+        a = self.a_tensor.copy()
+        ht.div(a, self.a_scalar, out=a, where=a > 2)
+        self.assertTrue(ht.equal(a, result_where))
+        out = ht.array([[1.0, 2.0], [3.0, 4.0]], split=1)
+        where = ht.array([[True, True], [False, True]], split=None)
+        ht.div(out, self.another_tensor, out=out, where=where)
+        self.assertTrue(ht.equal(out, ht.array([[0.5, 1.0], [3.0, 2.0]])))
+        self.assertEqual(1, out.split)
+        out = ht.array([[1.0, 2.0], [3.0, 4.0]], split=0)
+        where.resplit_(0)
+        ht.div(out, self.another_tensor, out=out, where=where)
+        self.assertTrue(ht.equal(out, ht.array([[0.5, 1.0], [3.0, 2.0]])))
+        self.assertEqual(0, out.split)
+
+        result_where_broadcasted = ht.array([[1.0, 1.0], [3.0, 2.0]])
+        a = self.a_tensor.copy()
+        ht.div(a, self.a_scalar, out=a, where=ht.array([False, True]))
+        self.assertTrue(ht.equal(a, result_where_broadcasted))
+        a = self.a_tensor.copy().resplit_(0)
+        ht.div(a, self.a_scalar, out=a, where=ht.array([False, True], split=0))
+        self.assertTrue(ht.equal(a, result_where_broadcasted))
+        self.assertEqual(0, a.split)
+
         with self.assertRaises(ValueError):
             ht.div(self.a_tensor, self.another_vector)
         with self.assertRaises(TypeError):
             ht.div(self.a_tensor, self.erroneous_type)
         with self.assertRaises(TypeError):
             ht.div("T", "s")
+        with self.assertRaises(ValueError):
+            ht.div(self.a_split_tensor, self.a_tensor, out=ht.empty((2, 2), split=None))
+        with self.assertRaises(NotImplementedError):
+            ht.div(
+                self.a_split_tensor,
+                self.a_tensor,
+                where=ht.array([[True, False], [False, True]], split=1),
+            )
 
     def test_fmod(self):
         result = ht.array([[1.0, 0.0], [1.0, 0.0]])
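The two new `assertRaises` blocks pin down the current limits of the feature; condensed into a standalone sketch (array values are illustrative, the errors are exactly those asserted in the tests above):

```python
import heat as ht

a = ht.ones((2, 2), split=0)  # distributed operand
b = ht.ones((2, 2))

# `out` must match the output's split axis, otherwise ValueError:
try:
    ht.div(a, b, out=ht.empty((2, 2), split=None))
except ValueError:
    print("out buffer has the wrong split axis")

# A `where` mask distributed along a different axis than the output is
# not supported yet and raises NotImplementedError:
try:
    ht.div(a, b, where=ht.array([[True, False], [False, True]], split=1))
except NotImplementedError:
    print("mask split must match the output split")
```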