From c6f4e3afb362b18e79004b58272a7c4028ce68e5 Mon Sep 17 00:00:00 2001
From: TomNicholas
Date: Tue, 6 Feb 2024 04:33:29 -0500
Subject: [PATCH] extract dtypes from underlying duck arrays without coercing to numpy

Add as_duck_array(), which returns duck arrays unchanged, unwraps objects
exposing get_duck_array(), and only falls back to xp.asarray() for
everything else. as_shared_dtype() now uses it to collect its inputs.
---
 xarray/core/duck_array_ops.py | 24 ++++++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -217,23 +217,35 @@ def asarray(data, xp=np):
     return data if is_duck_array(data) else xp.asarray(data)
 
 
+def as_duck_array(data, xp=np):
+    """Return ``data`` as a duck array, converting with ``xp`` only if needed."""
+    if is_duck_array(data):
+        return data
+    if hasattr(data, "get_duck_array"):
+        # must be a lazy indexing class wrapping a duck array
+        return data.get_duck_array()
+    return xp.asarray(data)
+
+
 def as_shared_dtype(scalars_or_arrays, xp=np):
-    """Cast a arrays to a shared dtype using xarray's type promotion rules."""
+    """Cast arrays to a shared dtype using xarray's type promotion rules."""
     array_type_cupy = array_type("cupy")
     if array_type_cupy and any(
         isinstance(x, array_type_cupy) for x in scalars_or_arrays
     ):
         import cupy as cp
 
-        arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
-    else:
-        arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
+        # A cupy input means every other input is converted with cupy too.
+        xp = cp
+    # Objects exposing get_duck_array() are unwrapped rather than passed to
+    # xp.asarray().
+    duckarrays = [as_duck_array(x, xp=xp) for x in scalars_or_arrays]
     # Pass arrays directly instead of dtypes to result_type so scalars
     # get handled properly.
     # Note that result_type() safely gets the dtype from dask arrays without
     # evaluating them.
-    out_type = dtypes.result_type(*arrays)
-    return [astype(x, out_type, copy=False) for x in arrays]
+    out_type = dtypes.result_type(*duckarrays)
+    return [astype(x, out_type, copy=False) for x in duckarrays]
 
 
 def broadcast_to(array, shape):