Skip to content

Commit

Permalink
ProxyObject implement __array_function__ (#843)
Browse files Browse the repository at this point in the history
Implements `__array_function__`, which should fix #841

cc. @pentschev

Authors:
  - Mads R. B. Kristensen (https://github.com/madsbk)

Approvers:
  - Peter Andreas Entschev (https://github.com/pentschev)

URL: #843
  • Loading branch information
madsbk authored Feb 2, 2022
1 parent 1068538 commit 13a3ede
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
31 changes: 22 additions & 9 deletions dask_cuda/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,15 +471,28 @@ def __setattr__(self, name: str, val):
object.__setattr__(pxy.deserialize(nbytes=self.__sizeof__()), name, val)
self._pxy_set(pxy)

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
inputs = tuple(
o._pxy_deserialize() if isinstance(o, ProxyObject) else o for o in inputs
)
kwargs = {
key: value._pxy_deserialize() if isinstance(value, ProxyObject) else value
for key, value in kwargs.items()
}
return self._pxy_deserialize().__array_ufunc__(ufunc, method, *inputs, **kwargs)
def __array_ufunc__(self, ufunc, method, *args, **kwargs):
from .proxify_device_objects import unproxify_device_objects

args, kwargs = unproxify_device_objects(args), unproxify_device_objects(kwargs)
return self._pxy_deserialize().__array_ufunc__(ufunc, method, *args, **kwargs)

def __array_function__(self, func, types, args, kwargs):
from .proxify_device_objects import unproxify_device_objects

kwargs = unproxify_device_objects(kwargs)
proxied = self._pxy_deserialize()

# Unproxify `args` and `types`
types = [t for t in types if not issubclass(t, type(self))]
args_proxied = []
for a in args:
if isinstance(a, type(self)):
types.append(a.__class__)
args_proxied.append(a._pxy_deserialize())
else:
args_proxied.append(a)
return proxied.__array_function__(func, types, args_proxied, kwargs)

def __str__(self):
return str(self._pxy_deserialize())
Expand Down
10 changes: 10 additions & 0 deletions dask_cuda/tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,13 @@ def test_sizeof_cudf():
pxy._pxy_cache = {}
assert a_size == pytest.approx(sizeof(pxy), rel=1e-2)
assert pxy._pxy_get().is_serialized()


def test_cupy_broadcast_to():
cupy = pytest.importorskip("cupy")
a = cupy.arange(10)
a_b = np.broadcast_to(a, (10, 10))
p_b = np.broadcast_to(proxy_object.asproxy(a), (10, 10))

assert a_b.shape == p_b.shape
assert (a_b == p_b).all()

0 comments on commit 13a3ede

Please sign in to comment.