From 77d87aeb5f152baa61cdd273c00d7936004460c1 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Thu, 16 Jul 2020 17:51:15 +0100 Subject: [PATCH 1/5] Support cupy in as_shared_dtype --- xarray/core/duck_array_ops.py | 14 ++++++++++---- xarray/core/pycompat.py | 8 ++++++++ xarray/core/variable.py | 12 +++++++----- xarray/tests/test_cupy.py | 8 ++++++++ 4 files changed, 33 insertions(+), 9 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index df579d23544..71079f13d1e 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -14,7 +14,7 @@ from . import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils from .nputils import nanfirst, nanlast -from .pycompat import dask_array_type +from .pycompat import dask_array_type, cupy_array_type try: import dask.array as dask_array @@ -158,17 +158,23 @@ def trapz(y, x, axis): ) -def asarray(data): +def asarray(data, xp=np): return ( data if (isinstance(data, dask_array_type) or hasattr(data, "__array_function__")) - else np.asarray(data) + else xp.asarray(data) ) def as_shared_dtype(scalars_or_arrays): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" - arrays = [asarray(x) for x in scalars_or_arrays] + + if any([isinstance(x, cupy_array_type) for x in scalars_or_arrays]): + import cupy as cp + + arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] + else: + arrays = [asarray(x) for x in scalars_or_arrays] # Pass arrays directly instead of dtypes to result_type so scalars # get handled properly. # Note that result_type() safely gets the dtype from dask arrays without diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index aaf52b9f295..20089c767b5 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -17,3 +17,11 @@ sparse_array_type = (sparse.SparseArray,) except ImportError: # pragma: no cover sparse_array_type = () + +try: + # solely for isinstance checks + import cupy + + cupy_array_type = (cupy.core.core.ndarray,) +except ImportError: # pragma: no cover + cupy_array_type = () diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c505c749557..f9a41b2cee9 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -33,7 +33,7 @@ ) from .npcompat import IS_NEP18_ACTIVE from .options import _get_keep_attrs -from .pycompat import dask_array_type, integer_types +from .pycompat import cupy_array_type, dask_array_type, integer_types from .utils import ( OrderedSet, _default, @@ -45,9 +45,8 @@ ) NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( - indexing.ExplicitlyIndexed, - pd.Index, -) + dask_array_type + (indexing.ExplicitlyIndexed, pd.Index,) + dask_array_type + cupy_array_type +) # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) # type: ignore @@ -257,7 +256,10 @@ def _as_array_or_item(data): TODO: remove this (replace with np.asarray) once these issues are fixed """ - data = np.asarray(data) + if isinstance(data, cupy_array_type): + data = data.get() + else: + data = np.asarray(data) if data.ndim == 0: if data.dtype.kind == "M": data = np.datetime64(data, "ns") diff --git a/xarray/tests/test_cupy.py b/xarray/tests/test_cupy.py index 624e78d9271..f9b4fbe9c28 100644 --- a/xarray/tests/test_cupy.py +++ b/xarray/tests/test_cupy.py @@ -48,3 +48,11 @@ def test_check_data_stays_on_gpu(toy_weather_data): """Perform some operations and check the data stays on the GPU.""" freeze = (toy_weather_data["tmin"] <= 0).groupby("time.month").mean("time") assert isinstance(freeze.data, cp.core.core.ndarray) + + +def test_where(): + from xarray.core.duck_array_ops import where + + data = cp.zeros(10) + + assert where(data < 1, 1, data).all() From 2dbd79abffc1785543209ad5c540877db51efad6 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Thu, 16 Jul 2020 17:52:19 +0100 Subject: [PATCH 2/5] Lint --- xarray/core/duck_array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 71079f13d1e..e82978ef600 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -14,7 +14,7 @@ from . import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils from .nputils import nanfirst, nanlast -from .pycompat import dask_array_type, cupy_array_type +from .pycompat import cupy_array_type, dask_array_type try: import dask.array as dask_array From 557b624f1f95d75d61e3f6bd74b49ebf912f0108 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Thu, 23 Jul 2020 17:07:12 +0100 Subject: [PATCH 3/5] Update xarray/core/pycompat.py --- xarray/core/pycompat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 20089c767b5..dcb78d17cf8 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -22,6 +22,6 @@ # solely for isinstance checks import cupy - cupy_array_type = (cupy.core.core.ndarray,) + cupy_array_type = (cupy.ndarray,) except ImportError: # pragma: no cover cupy_array_type = () From 72fc2fe368ed66d68e136ae2042885182164b162 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Fri, 24 Jul 2020 15:11:06 +0100 Subject: [PATCH 4/5] Add type test --- xarray/tests/test_cupy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_cupy.py b/xarray/tests/test_cupy.py index f9b4fbe9c28..0276b8ebc08 100644 --- a/xarray/tests/test_cupy.py +++ b/xarray/tests/test_cupy.py @@ -55,4 +55,6 @@ def test_where(): data = cp.zeros(10) - assert where(data < 1, 1, data).all() + output = where(data < 1, 1, data).all() + assert output + assert isinstance(output, cp.ndarray) From f8c5eda6d99fbb9f569ae7e48f95aa41a6455848 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 24 Jul 2020 09:05:52 -0600 Subject: [PATCH 5/5] mypy ignore cupy --- setup.cfg | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.cfg b/setup.cfg index 42dc53bb882..ad0b12a3e32 100644 --- a/setup.cfg +++ b/setup.cfg @@ -138,6 +138,8 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-cftime.*] ignore_missing_imports = True +[mypy-cupy.*] +ignore_missing_imports = True [mypy-dask.*] ignore_missing_imports = True [mypy-distributed.*]