Skip to content

Commit

Permalink
extract dtypes from underlying duck arrays without coercing to numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
TomNicholas committed Feb 6, 2024
1 parent c9ba2be commit c6f4e3a
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,26 +214,32 @@ def astype(data, dtype, **kwargs):


def asarray(data, xp=np):
print(data)
print(type(data))
return data if is_duck_array(data) else xp.asarray(data)


def as_shared_dtype(scalars_or_arrays, xp=np):
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
array_type_cupy = array_type("cupy")
if array_type_cupy and any(
isinstance(x, array_type_cupy) for x in scalars_or_arrays
):
import cupy as cp

arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
def as_duck_array(data, xp=np):
if is_duck_array(data):
return data
elif hasattr(data, "get_duck_array"):
# must be a lazy indexing class wrapping a duck array
return data.get_duck_array()
else:
arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
# Pass arrays directly instead of dtypes to result_type so scalars
# get handled properly.
# Note that result_type() safely gets the dtype from dask arrays without
# evaluating them.
out_type = dtypes.result_type(*arrays)
return [astype(x, out_type, copy=False) for x in arrays]
array_type_cupy = array_type("cupy")
if array_type_cupy and any(isinstance(data, array_type_cupy)):
import cupy as cp

return asarray(data, xp=cp)
else:
return asarray(data, xp=xp)


def as_shared_dtype(scalars_or_arrays, xp=np):
"""Cast arrays to a shared dtype using xarray's type promotion rules."""
duckarrays = [as_duck_array(obj, xp=xp) for obj in scalars_or_arrays]
out_type = dtypes.result_type(*duckarrays)
return [astype(x, out_type, copy=False) for x in duckarrays]


def broadcast_to(array, shape):
Expand Down Expand Up @@ -327,7 +333,7 @@ def sum_where(data, axis=None, dtype=None, where=None):
def where(condition, x, y):
"""Three argument where() with better dtype promotion rules."""
xp = get_array_namespace(condition)
return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
return xp.where(condition, *as_shared_dtype([x, y]))


def where_method(data, cond, other=dtypes.NA):
Expand All @@ -350,14 +356,14 @@ def concatenate(arrays, axis=0):
arrays[0], np.ndarray
):
xp = get_array_namespace(arrays[0])
return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis)
return xp.concat(as_shared_dtype(arrays), axis=axis)
return _concatenate(as_shared_dtype(arrays), axis=axis)


def stack(arrays, axis=0):
"""stack() with better dtype promotion rules."""
xp = get_array_namespace(arrays[0])
return xp.stack(as_shared_dtype(arrays, xp=xp), axis=axis)
return xp.stack(as_shared_dtype(arrays), axis=axis)


def reshape(array, shape):
Expand Down

0 comments on commit c6f4e3a

Please sign in to comment.