Skip to content

Commit

Permalink
ProxyObject to support matrix multiplication (#849)
Browse files Browse the repository at this point in the history
Implements `ProxyObject.__matmul__` and `ProxyObject.__imatmul__`.

@pentschev can you check if this fixes #843 (comment)

Authors:
  - Mads R. B. Kristensen (https://github.com/madsbk)

Approvers:
  - Peter Andreas Entschev (https://github.com/pentschev)

URL: #849
  • Loading branch information
madsbk authored Feb 2, 2022
1 parent 13a3ede commit 9a92864
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
11 changes: 11 additions & 0 deletions dask_cuda/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,9 @@ def __xor__(self, other):
def __or__(self, other):
return self._pxy_deserialize() | other

def __matmul__(self, other):
return self._pxy_deserialize().__matmul__(unproxy(other))

def __radd__(self, other):
return other + self._pxy_deserialize()

Expand Down Expand Up @@ -741,6 +744,14 @@ def __ior__(self, other):
self._pxy_set(pxy)
return self

def __imatmul__(self, other):
pxy = self._pxy_get(copy=True)
proxied = pxy.deserialize(nbytes=self.__sizeof__())
proxied @= other
pxy.obj = proxied
self._pxy_set(pxy)
return self

def __neg__(self):
return -self._pxy_deserialize()

Expand Down
24 changes: 24 additions & 0 deletions dask_cuda/tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,27 @@ def test_cupy_broadcast_to():

assert a_b.shape == p_b.shape
assert (a_b == p_b).all()


def test_cupy_matmul():
cupy = pytest.importorskip("cupy")
a, b = cupy.arange(10), cupy.arange(10)
c = a @ b
assert c == proxy_object.asproxy(a) @ b
assert c == a @ proxy_object.asproxy(b)
assert c == proxy_object.asproxy(a) @ proxy_object.asproxy(b)


def test_cupy_imatmul():
cupy = pytest.importorskip("cupy")
a = cupy.arange(9).reshape(3, 3)
c = a.copy()
c @= a

a1 = a.copy()
a1 @= proxy_object.asproxy(a)
assert (a1 == c).all()

a2 = proxy_object.asproxy(a.copy())
a2 @= a
assert (a2 == c).all()

0 comments on commit 9a92864

Please sign in to comment.