diff --git a/setup.cfg b/setup.cfg index 42dc53bb882..ad0b12a3e32 100644 --- a/setup.cfg +++ b/setup.cfg @@ -138,6 +138,8 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-cftime.*] ignore_missing_imports = True +[mypy-cupy.*] +ignore_missing_imports = True [mypy-dask.*] ignore_missing_imports = True [mypy-distributed.*] diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index df579d23544..e82978ef600 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -14,7 +14,7 @@ from . import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils from .nputils import nanfirst, nanlast -from .pycompat import dask_array_type +from .pycompat import cupy_array_type, dask_array_type try: import dask.array as dask_array @@ -158,17 +158,23 @@ def trapz(y, x, axis): ) -def asarray(data): +def asarray(data, xp=np): return ( data if (isinstance(data, dask_array_type) or hasattr(data, "__array_function__")) - else np.asarray(data) + else xp.asarray(data) ) def as_shared_dtype(scalars_or_arrays): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" - arrays = [asarray(x) for x in scalars_or_arrays] + + if any([isinstance(x, cupy_array_type) for x in scalars_or_arrays]): + import cupy as cp + + arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] + else: + arrays = [asarray(x) for x in scalars_or_arrays] # Pass arrays directly instead of dtypes to result_type so scalars # get handled properly. # Note that result_type() safely gets the dtype from dask arrays without diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index aaf52b9f295..dcb78d17cf8 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -17,3 +17,11 @@ sparse_array_type = (sparse.SparseArray,) except ImportError: # pragma: no cover sparse_array_type = () + +try: + # solely for isinstance checks + import cupy + + cupy_array_type = (cupy.ndarray,) +except ImportError: # pragma: no cover + cupy_array_type = () diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c505c749557..f9a41b2cee9 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -33,7 +33,7 @@ ) from .npcompat import IS_NEP18_ACTIVE from .options import _get_keep_attrs -from .pycompat import dask_array_type, integer_types +from .pycompat import cupy_array_type, dask_array_type, integer_types from .utils import ( OrderedSet, _default, @@ -45,9 +45,8 @@ ) NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( - indexing.ExplicitlyIndexed, - pd.Index, -) + dask_array_type + (indexing.ExplicitlyIndexed, pd.Index,) + dask_array_type + cupy_array_type +) # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) # type: ignore @@ -257,7 +256,10 @@ def _as_array_or_item(data): TODO: remove this (replace with np.asarray) once these issues are fixed """ - data = np.asarray(data) + if isinstance(data, cupy_array_type): + data = data.get() + else: + data = np.asarray(data) if data.ndim == 0: if data.dtype.kind == "M": data = np.datetime64(data, "ns") diff --git a/xarray/tests/test_cupy.py b/xarray/tests/test_cupy.py index 624e78d9271..0276b8ebc08 100644 --- a/xarray/tests/test_cupy.py +++ b/xarray/tests/test_cupy.py @@ -48,3 +48,13 @@ def test_check_data_stays_on_gpu(toy_weather_data): """Perform some operations and check the data stays on the GPU.""" freeze = (toy_weather_data["tmin"] <= 0).groupby("time.month").mean("time") assert isinstance(freeze.data, cp.core.core.ndarray) + + +def test_where(): + from xarray.core.duck_array_ops import where + + data = cp.zeros(10) + + output = where(data < 1, 1, data).all() + assert output + assert isinstance(output, cp.ndarray)