Skip to content

Commit

Permalink
Merge pull request #840 from helmholtz-analytics/enhancement/839-vecdot
Browse files Browse the repository at this point in the history
implement vecdot
  • Loading branch information
coquelin77 authored Aug 2, 2021
2 parents f0afedf + cd6881a commit 50a743a
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 2 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
## Bug Fixes
- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension

# Feature Additions
## Feature Additions

### Linear Algebra
- [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
## Manipulations
- [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll`

Expand Down
55 changes: 54 additions & 1 deletion heat/core/linalg/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,18 @@
from .. import sanitation
from .. import types

__all__ = ["dot", "matmul", "norm", "outer", "projection", "trace", "transpose", "tril", "triu"]
__all__ = [
"dot",
"matmul",
"norm",
"outer",
"projection",
"trace",
"transpose",
"tril",
"triu",
"vecdot",
]


def dot(a: DNDarray, b: DNDarray, out: Optional[DNDarray] = None) -> Union[DNDarray, float]:
Expand All @@ -39,6 +50,11 @@ def dot(a: DNDarray, b: DNDarray, out: Optional[DNDarray] = None) -> Union[DNDar
Second input DNDarray
out : DNDarray, optional
Output buffer.
See Also
--------
vecdot
Supports (vector) dot along an axis.
"""
if isinstance(a, (float, int)) or isinstance(b, (float, int)) or a.ndim == 0 or b.ndim == 0:
# 3. If either a or b is 0-D (scalar), it is equivalent to multiply and using numpy.multiply(a, b) or a * b is preferred.
Expand Down Expand Up @@ -1638,3 +1654,40 @@ def triu(m: DNDarray, k: int = 0) -> DNDarray:

DNDarray.triu: Callable[[DNDarray, int], DNDarray] = lambda self, k=0: triu(self, k)
DNDarray.triu.__doc__ = triu.__doc__


def vecdot(
x1: DNDarray, x2: DNDarray, axis: Optional[int] = None, keepdim: Optional[bool] = None
) -> DNDarray:
"""
Computes the (vector) dot product of two DNDarrays.
Parameters
----------
x1 : DNDarray
first input array.
x2 : DNDarray
second input array. Must be compatible with x1.
axis : int, optional
axis over which to compute the dot product. The last dimension is used if 'None'.
keepdim : bool, optional
If this is set to 'True', the axes which are reduced are left in the result as dimensions with size one.
See Also
--------
dot
NumPy-like dot function.
Examples
--------
>>> ht.vecdot(ht.full((3,3,3),3), ht.ones((3,3)), axis=0)
DNDarray([[9., 9., 9.],
[9., 9., 9.],
[9., 9., 9.]], dtype=ht.float32, device=cpu:0, split=None)
"""
m = arithmetics.mul(x1, x2)

if axis is None:
axis = m.ndim - 1

return arithmetics.sum(m, axis=axis, keepdim=keepdim)
18 changes: 18 additions & 0 deletions heat/core/linalg/tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1672,3 +1672,21 @@ def test_triu(self):
self.assertTrue(result.larray[-1, 0] == 0)
if result.comm.rank == result.shape[0] - 1:
self.assertTrue(result.larray[0, -1] == 1)

def test_vecdot(self):
a = ht.array([1, 1, 1])
b = ht.array([1, 2, 3])

c = ht.linalg.vecdot(a, b)

self.assertEqual(c.dtype, ht.int64)
self.assertEqual(c.device, a.device)
self.assertTrue(ht.equal(c, ht.array([6])))

a = ht.full((4, 4), 2, split=0)
b = ht.ones(4)

c = ht.linalg.vecdot(a, b, axis=0, keepdim=True)
self.assertEqual(c.dtype, ht.float32)
self.assertEqual(c.device, a.device)
self.assertTrue(ht.equal(c, ht.array([[8, 8, 8, 8]])))

0 comments on commit 50a743a

Please sign in to comment.