diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index c41877019..17d9d4adf 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -471,15 +471,28 @@ def __setattr__(self, name: str, val): object.__setattr__(pxy.deserialize(nbytes=self.__sizeof__()), name, val) self._pxy_set(pxy) - def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): - inputs = tuple( - o._pxy_deserialize() if isinstance(o, ProxyObject) else o for o in inputs - ) - kwargs = { - key: value._pxy_deserialize() if isinstance(value, ProxyObject) else value - for key, value in kwargs.items() - } - return self._pxy_deserialize().__array_ufunc__(ufunc, method, *inputs, **kwargs) + def __array_ufunc__(self, ufunc, method, *args, **kwargs): + from .proxify_device_objects import unproxify_device_objects + + args, kwargs = unproxify_device_objects(args), unproxify_device_objects(kwargs) + return self._pxy_deserialize().__array_ufunc__(ufunc, method, *args, **kwargs) + + def __array_function__(self, func, types, args, kwargs): + from .proxify_device_objects import unproxify_device_objects + + kwargs = unproxify_device_objects(kwargs) + proxied = self._pxy_deserialize() + + # Unproxify `args` and `types` + types = [t for t in types if not issubclass(t, type(self))] + args_proxied = [] + for a in args: + if isinstance(a, type(self)): + types.append(a.__class__) + args_proxied.append(a._pxy_deserialize()) + else: + args_proxied.append(a) + return proxied.__array_function__(func, types, args_proxied, kwargs) def __str__(self): return str(self._pxy_deserialize()) diff --git a/dask_cuda/tests/test_proxy.py b/dask_cuda/tests/test_proxy.py index 7fe461c95..417d0d83b 100644 --- a/dask_cuda/tests/test_proxy.py +++ b/dask_cuda/tests/test_proxy.py @@ -610,3 +610,13 @@ def test_sizeof_cudf(): pxy._pxy_cache = {} assert a_size == pytest.approx(sizeof(pxy), rel=1e-2) assert pxy._pxy_get().is_serialized() + + +def test_cupy_broadcast_to(): + cupy = pytest.importorskip("cupy") + a = cupy.arange(10) + a_b = np.broadcast_to(a, (10, 10)) + p_b = np.broadcast_to(proxy_object.asproxy(a), (10, 10)) + + assert a_b.shape == p_b.shape + assert (a_b == p_b).all()