From 9eb180b7b3df869c21518a17c6fa9eb456551d21 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Thu, 18 Apr 2024 14:52:02 +0200 Subject: [PATCH] (feat): Support for `pandas` `ExtensionArray` (#8723) * (feat): first pass supporting extension arrays * (feat): categorical tests + functionality * (feat): use multiple dispatch for unimplemented ops * (feat): implement (not really) broadcasting * (chore): add more `groupby` tests * (fix): fix more groupby incompatibility * (bug): fix unused categories * (chore): refactor dispatched methods + tests * (fix): shared type should check for extension arrays first and then fall back to numpy * (refactor): tests moved * (chore): more higher level tests * (feat): to/from dataframe * (chore): check for plum import * (fix): `__setitem__`/`__getitem__` * (chore): disallow stacking * (fix): `pyproject.toml` * (fix): `as_shared_type` fix * (chore): add variable tests * (fix): dask + categoricals * (chore): notes/docs * (chore): remove old testing file * (chore): remove commented out code * (fix): import plum dispatch * (refactor): use `is_extension_array_dtype` as much as possible * (refactor): `extension_array`->`array` + move to `indexing` * (refactor): change order of classes * (chore): add small pyarrow test * (fix): fix some mypy issues * (fix): don't register unregisterable method * (fix): appease mypy * (fix): more sensible default implementations allow most use without `plum` * (fix): handling `pyarrow` tests * (fix): actually do import correctly * (fix): `reduce` condition * (fix): column ordering for dataframes * (refactor): remove encoding business * (refactor): raise error for dask + extension array * (fix): only wrap `ExtensionDuckArray` that has a `.array` which is a pandas extension array * (fix): use duck array equality method, not pandas * (refactor): bye plum!
* (fix): `and` to `or` for casting to `ExtensionDuckArray` * (fix): check for class, not type * (fix): only support native endianness * (refactor): no need for superfluous checks in `_maybe_wrap_data` * (chore): clean up docs to no longer reference `plum` * (fix): no longer allow `ExtensionDuckArray` to wrap `ExtensionDuckArray` * (refactor): move `implements` logic to `indexing` * (refactor): `indexing.py` -> `extension_array.py` * (refactor): `ExtensionDuckArray` -> `PandasExtensionArray` * (fix): add writeable property * (fix): don't check writeable for `PandasExtensionArray` * (fix): move check earlier * (refactor): correct guard clause * (chore): remove unnecessary `AttributeError` * (feat): singleton wrapped as array * (feat): remove shared dtype casting * (feat): loop once over `dataframe.items` * (feat): add `__len__` attribute * (fix): ensure constructor receives `pd.Categorical` * Update xarray/core/extension_array.py Co-authored-by: Deepak Cherian * Update xarray/core/extension_array.py Co-authored-by: Deepak Cherian * (fix): drop condition for categorical corrected * Apply suggestions from code review * (chore): test `chunk` behavior * Update xarray/core/variable.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * (fix): bring back error * (chore): add test for dropping cat for mean * Update whats-new.rst * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Deepak Cherian Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/whats-new.rst | 6 +- properties/test_pandas_roundtrip.py | 4 +- pyproject.toml | 1 + xarray/core/dataset.py | 60 +++++++++--- xarray/core/duck_array_ops.py | 19 +++- xarray/core/extension_array.py | 136 ++++++++++++++++++++++++++++ xarray/core/types.py | 3 + xarray/core/variable.py | 10 ++ xarray/testing/strategies.py | 10 +- xarray/tests/__init__.py | 18 +++- xarray/tests/test_concat.py | 27 +++++- xarray/tests/test_dataset.py | 42 ++++++--- xarray/tests/test_duck_array_ops.py | 108 ++++++++++++++++++++++ xarray/tests/test_groupby.py | 12 ++- xarray/tests/test_merge.py | 2 +- xarray/tests/test_variable.py | 19 ++++ 16 files changed, 434 insertions(+), 43 deletions(-) create mode 100644 xarray/core/extension_array.py diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1c67ba1ee7f..4fe51708646 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -24,7 +24,11 @@ New Features ~~~~~~~~~~~~ - New "random" method for converting to and from 360_day calendars (:pull:`8603`). By `Pascal Bourgault `_. - +- Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array + by supporting 1D `ExtensionArray` objects internally where possible. Thus, a `Dataset` initialized with a `pd.Categorical`, + for example, will retain the object. However, operations that are not possible on the `ExtensionArray`, + such as broadcasting, remain unsupported. + By `Ilan Gold `_.
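An illustrative sketch of the new behavior (editorial note, not part of the patch; it follows the tests added below, and the variable and coordinate names are made up)::

    import pandas as pd
    import xarray as xr

    ds = xr.Dataset(
        {"cat": ("x", pd.Categorical(["a", "b", "a"])), "num": ("x", [1.0, 2.0, 3.0])},
        coords={"x": [0, 1, 2]},
    )
    # the categorical is kept as a pandas extension array rather than being
    # coerced to a numpy object array
    assert pd.api.types.is_extension_array_dtype(ds["cat"].dtype)
    # the dtype also survives conversion to a DataFrame alongside numpy columns
    assert isinstance(ds.to_dataframe()["cat"].dtype, pd.CategoricalDtype)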
Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 5c0f14976e6..3d87fcce1d9 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -17,7 +17,9 @@ from hypothesis import given # isort:skip numeric_dtypes = st.one_of( - npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes() + npst.unsigned_integer_dtypes(endianness="="), + npst.integer_dtypes(endianness="="), + npst.floating_dtypes(endianness="="), ) numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt)) diff --git a/pyproject.toml b/pyproject.toml index 15a3c99eec2..8cbd395b2a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,6 +130,7 @@ module = [ "opt_einsum.*", "pandas.*", "pooch.*", + "pyarrow.*", "pydap.*", "pytest.*", "scipy.*", diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4c80a47209e..96f3be00995 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -24,6 +24,7 @@ from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload import numpy as np +from pandas.api.types import is_extension_array_dtype # remove once numpy 2.0 is the oldest supported version try: @@ -6852,10 +6853,13 @@ def reduce( if ( # Some reduction functions (e.g. std, var) need to run on variables # that don't have the reduce dims: PR5393 - not reduce_dims - or not numeric_only - or np.issubdtype(var.dtype, np.number) - or (var.dtype == np.bool_) + not is_extension_array_dtype(var.dtype) + and ( + not reduce_dims + or not numeric_only + or np.issubdtype(var.dtype, np.number) + or (var.dtype == np.bool_) + ) ): # prefer to aggregate over axis=None rather than # axis=(0, 1) if they will be equivalent, because @@ -7168,13 +7172,37 @@ def to_pandas(self) -> pd.Series | pd.DataFrame: ) def _to_dataframe(self, ordered_dims: Mapping[Any, int]): - columns = [k for k in self.variables if k not in self.dims] + columns_in_order = [k for k in self.variables if k not in self.dims] + non_extension_array_columns = [ + k + for k in columns_in_order + if not is_extension_array_dtype(self.variables[k].data) + ] + extension_array_columns = [ + k + for k in columns_in_order + if is_extension_array_dtype(self.variables[k].data) + ] data = [ self._variables[k].set_dims(ordered_dims).values.reshape(-1) - for k in columns + for k in non_extension_array_columns ] index = self.coords.to_index([*ordered_dims]) - return pd.DataFrame(dict(zip(columns, data)), index=index) + broadcasted_df = pd.DataFrame( + dict(zip(non_extension_array_columns, data)), index=index + ) + for extension_array_column in extension_array_columns: + extension_array = self.variables[extension_array_column].data.array + index = self[self.variables[extension_array_column].dims[0]].data + extension_array_df = pd.DataFrame( + {extension_array_column: extension_array}, + index=index, + ) + extension_array_df.index.name = self.variables[extension_array_column].dims[ + 0 + ] + broadcasted_df = broadcasted_df.join(extension_array_df) + return broadcasted_df[columns_in_order] def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame: """Convert this dataset into a pandas.DataFrame.
@@ -7321,11 +7349,13 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: "cannot convert a DataFrame with a non-unique MultiIndex into xarray" ) - # Cast to a NumPy array first, in case the Series is a pandas Extension - # array (which doesn't have a valid NumPy dtype) - # TODO: allow users to control how this casting happens, e.g., by - # forwarding arguments to pandas.Series.to_numpy? - arrays = [(k, np.asarray(v)) for k, v in dataframe.items()] + arrays = [] + extension_arrays = [] + for k, v in dataframe.items(): + if not is_extension_array_dtype(v): + arrays.append((k, np.asarray(v))) + else: + extension_arrays.append((k, v)) indexes: dict[Hashable, Index] = {} index_vars: dict[Hashable, Variable] = {} @@ -7339,6 +7369,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: xr_idx = PandasIndex(lev, dim) indexes[dim] = xr_idx index_vars.update(xr_idx.create_variables()) + arrays += [(k, np.asarray(v)) for k, v in extension_arrays] + extension_arrays = [] else: index_name = idx.name if idx.name is not None else "index" dims = (index_name,) @@ -7352,7 +7384,9 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: obj._set_sparse_data_from_dataframe(idx, arrays, dims) else: obj._set_numpy_data_from_dataframe(idx, arrays, dims) - return obj + for name, extension_array in extension_arrays: + obj[name] = (dims, extension_array) + return obj[dataframe.columns] if len(dataframe.columns) else obj def to_dask_dataframe( self, dim_order: Sequence[Hashable] | None = None, set_index: bool = False diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index ef497e78ebf..d95dfa566cc 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -32,6 +32,7 @@ from numpy import concatenate as _concatenate from numpy.lib.stride_tricks import sliding_window_view # noqa from packaging.version import Version +from pandas.api.types import is_extension_array_dtype from xarray.core import dask_array_ops, dtypes, nputils from xarray.core.options import OPTIONS @@ -156,7 +157,7 @@ def isnull(data): return full_like(data, dtype=bool, fill_value=False) else: # at this point, array should have dtype=object - if isinstance(data, np.ndarray): + if isinstance(data, np.ndarray) or is_extension_array_dtype(data): return pandas_isnull(data) else: # Not reachable yet, but intended for use with other duck array @@ -221,9 +222,19 @@ def asarray(data, xp=np): def as_shared_dtype(scalars_or_arrays, xp=np): """Cast arrays to a shared dtype using xarray's type promotion rules.""" - array_type_cupy = array_type("cupy") - if array_type_cupy and any( - isinstance(x, array_type_cupy) for x in scalars_or_arrays + if any(is_extension_array_dtype(x) for x in scalars_or_arrays): + extension_array_types = [ + x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x) + ] + if len(extension_array_types) == len(scalars_or_arrays) and all( + isinstance(x, type(extension_array_types[0])) for x in extension_array_types + ): + return scalars_or_arrays + raise ValueError( + f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}" + ) + elif (array_type_cupy := array_type("cupy")) and any( + isinstance(x, array_type_cupy) for x in scalars_or_arrays ): import cupy as cp diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py new file mode 100644 index 00000000000..6521e425615 --- /dev/null +++ b/xarray/core/extension_array.py
@@ -0,0 +1,136 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Callable, Generic + +import numpy as np +import pandas as pd +from pandas.api.types import is_extension_array_dtype + +from xarray.core.types import DTypeLikeSave, T_ExtensionArray + +HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {} + + +def implements(numpy_function): + """Register an __array_function__ implementation for PandasExtensionArray objects.""" + + def decorator(func): + HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func + return func + + return decorator + + +@implements(np.issubdtype) +def __extension_duck_array__issubdtype( + extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave +) -> bool: + return False # never want a function to think a pandas extension dtype is a subtype of numpy + + +@implements(np.broadcast_to) +def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): + if shape[0] == len(arr) and len(shape) == 1: + return arr + raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.") + + +@implements(np.stack) +def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): + raise NotImplementedError("Cannot stack 1d-only pandas categorical array.") + + +@implements(np.concatenate) +def __extension_duck_array__concatenate( + arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None +) -> T_ExtensionArray: + return type(arrays[0])._concat_same_type(arrays) + + +@implements(np.where) +def __extension_duck_array__where( + condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray +) -> T_ExtensionArray: + if ( + isinstance(x, pd.Categorical) + and isinstance(y, pd.Categorical) + and x.dtype != y.dtype + ): + x = x.add_categories(set(y.categories).difference(set(x.categories))) + y = y.add_categories(set(x.categories).difference(set(y.categories))) + return pd.Series(x).where(condition, pd.Series(y)).array + + +class PandasExtensionArray(Generic[T_ExtensionArray]): + array: T_ExtensionArray + + def __init__(self, array: T_ExtensionArray): + """NEP-18 compliant wrapper for pandas extension arrays. + + Parameters + ---------- + array : T_ExtensionArray + The array being wrapped, e.g., upon :py:class:`xarray.Variable` creation. + """ + if not isinstance(array, pd.api.extensions.ExtensionArray): + raise TypeError(f"{array} is not a pandas ExtensionArray.") + self.array = array + + def __array_function__(self, func, types, args, kwargs): + def replace_duck_with_extension_array(args) -> list: + args_as_list = list(args) + for index, value in enumerate(args_as_list): + if isinstance(value, PandasExtensionArray): + args_as_list[index] = value.array + elif isinstance( + value, tuple + ): # should handle more than just tuple? iterable?
+ args_as_list[index] = tuple( + replace_duck_with_extension_array(value) + ) + elif isinstance(value, list): + args_as_list[index] = replace_duck_with_extension_array(value) + return args_as_list + + args = tuple(replace_duck_with_extension_array(args)) + if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS: + return func(*args, **kwargs) + res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) + if is_extension_array_dtype(res): + return type(self)[type(res)](res) + return res + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + return ufunc(*inputs, **kwargs) + + def __repr__(self): + return f"{type(self)}(array={repr(self.array)})" + + def __getattr__(self, attr: str) -> object: + return getattr(self.array, attr) + + def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: + item = self.array[key] + if is_extension_array_dtype(item): + return type(self)(item) + if np.isscalar(item): + return type(self)(type(self.array)([item])) + return item + + def __setitem__(self, key, val): + self.array[key] = val + + def __eq__(self, other): + if np.isscalar(other): + other = type(self)(type(self.array)([other])) + if isinstance(other, PandasExtensionArray): + return self.array == other.array + return self.array == other + + def __ne__(self, other): + return ~(self == other) + + def __len__(self): + return len(self.array) diff --git a/xarray/core/types.py b/xarray/core/types.py index 410cf3de00b..8f58e54d8cf 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -167,6 +167,9 @@ def copy( # hopefully in the future we can narrow this down more: T_DuckArray = TypeVar("T_DuckArray", bound=Any, covariant=True) +# For typing pandas extension arrays. +T_ExtensionArray = TypeVar("T_ExtensionArray", bound=pd.api.extensions.ExtensionArray) + ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] VarCompatible = Union["Variable", "ScalarOrArray"] diff --git a/xarray/core/variable.py b/xarray/core/variable.py index e89cf95411c..2229eaa2d24 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -13,11 +13,13 @@ import numpy as np import pandas as pd from numpy.typing import ArrayLike +from pandas.api.types import is_extension_array_dtype import xarray as xr # only for Dataset and DataArray from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from xarray.core.arithmetic import VariableArithmetic from xarray.core.common import AbstractArray +from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import ( BasicIndexer, OuterIndexer, @@ -47,6 +49,7 @@ NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( indexing.ExplicitlyIndexed, pd.Index, + pd.api.extensions.ExtensionArray, ) # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) @@ -184,6 +187,8 @@ def _maybe_wrap_data(data): """ if isinstance(data, pd.Index): return PandasIndexingAdapter(data) + if isinstance(data, pd.api.extensions.ExtensionArray): + return PandasExtensionArray[type(data)](data) return data @@ -2570,6 +2575,11 @@ def chunk( # type: ignore[override] dask.array.from_array """ + if is_extension_array_dtype(self): + raise ValueError( + f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first."
+ ) + if from_array_kwargs is None: from_array_kwargs = {} diff --git a/xarray/testing/strategies.py b/xarray/testing/strategies.py index d2503dfd535..449d0c793cc 100644 --- a/xarray/testing/strategies.py +++ b/xarray/testing/strategies.py @@ -45,7 +45,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]: Generates only those numpy dtypes which xarray can handle. Use instead of hypothesis.extra.numpy.scalar_dtypes in order to exclude weirder dtypes such as unicode, byte_string, array, or nested dtypes. - Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows. + Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows. Checks only native endianness. Requires the hypothesis package to be installed. @@ -56,10 +56,10 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]: # TODO should this be exposed publicly? # We should at least decide what the set of numpy dtypes that xarray officially supports is. return ( - npst.integer_dtypes() - | npst.unsigned_integer_dtypes() - | npst.floating_dtypes() - | npst.complex_number_dtypes() + npst.integer_dtypes(endianness="=") + | npst.unsigned_integer_dtypes(endianness="=") + | npst.floating_dtypes(endianness="=") + | npst.complex_number_dtypes(endianness="=") # | npst.datetime64_dtypes() # | npst.timedelta64_dtypes() # | npst.unicode_string_dtypes() diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 5007db9eeb2..3ce788dfb7f 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -18,6 +18,7 @@ from xarray import Dataset from xarray.core import utils from xarray.core.duck_array_ops import allclose_or_equiv # noqa: F401 +from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import ExplicitlyIndexed from xarray.core.options import set_options from xarray.core.variable import IndexVariable @@ -52,7 +53,9 @@ def assert_writeable(ds): readonly = [ name for name, var in ds.variables.items() - if not isinstance(var, IndexVariable) and not var.data.flags.writeable + if not isinstance(var, IndexVariable) + and not isinstance(var.data, PandasExtensionArray) + and not var.data.flags.writeable ] assert not readonly, readonly @@ -112,6 +115,7 @@ def _importorskip( has_fsspec, requires_fsspec = _importorskip("fsspec") has_iris, requires_iris = _importorskip("iris") has_numbagg, requires_numbagg = _importorskip("numbagg", "0.4.0") +has_pyarrow, requires_pyarrow = _importorskip("pyarrow") with warnings.catch_warnings(): warnings.filterwarnings( "ignore", @@ -307,6 +311,7 @@ def create_test_data( seed: int | None = None, add_attrs: bool = True, dim_sizes: tuple[int, int, int] = _DEFAULT_TEST_DIM_SIZES, + use_extension_array: bool = False, ) -> Dataset: rs = np.random.RandomState(seed) _vars = { @@ -329,7 +334,16 @@ def create_test_data( obj[v] = (dims, data) if add_attrs: obj[v].attrs = {"foo": "variable"} - + if use_extension_array: + obj["var4"] = ( + "dim1", + pd.Categorical( + rs.choice( + list(string.ascii_lowercase[: rs.randint(1, 5)]), + size=dim_sizes[0], + ) + ), + ) if dim_sizes == _DEFAULT_TEST_DIM_SIZES: numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64") else: diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 0cf4cc03a09..1ddb5a569bd 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -152,6 +152,21 @@ def test_concat_missing_var() -> None: assert_identical(actual, expected) +def test_concat_categorical() -> None: + data1 =
create_test_data(use_extension_array=True) + data2 = create_test_data(use_extension_array=True) + concatenated = concat([data1, data2], dim="dim1") + assert ( + concatenated["var4"] + == type(data2["var4"].variable.data.array)._concat_same_type( + [ + data1["var4"].variable.data.array, + data2["var4"].variable.data.array, + ] + ) + ).all() + + def test_concat_missing_multiple_consecutive_var() -> None: datasets = create_concat_datasets(3, seed=123) expected = concat(datasets, dim="day") @@ -451,8 +466,11 @@ def test_concat_fill_missing_variables( class TestConcatDataset: @pytest.fixture - def data(self) -> Dataset: - return create_test_data().drop_dims("dim3") + def data(self, request) -> Dataset: + use_extension_array = request.param if hasattr(request, "param") else False + return create_test_data(use_extension_array=use_extension_array).drop_dims( + "dim3" + ) def rectify_dim_order(self, data, dataset) -> Dataset: # return a new dataset with all variable dimensions transposed into @@ -464,7 +482,9 @@ def rectify_dim_order(self, data, dataset) -> Dataset: ) @pytest.mark.parametrize("coords", ["different", "minimal"]) - @pytest.mark.parametrize("dim", ["dim1", "dim2"]) + @pytest.mark.parametrize( + "dim,data", [["dim1", True], ["dim2", False]], indirect=["data"] + ) def test_concat_simple(self, data, dim, coords) -> None: datasets = [g for _, g in data.groupby(dim, squeeze=False)] assert_identical(data, concat(datasets, dim, coords=coords)) @@ -492,6 +512,7 @@ def test_concat_merge_variables_present_in_some_datasets(self, data) -> None: expected = data.copy().assign(foo=(["dim1", "bar"], foo)) assert_identical(expected, actual) + @pytest.mark.parametrize("data", [False], indirect=["data"]) def test_concat_2(self, data) -> None: dim = "dim2" datasets = [g.squeeze(dim) for _, g in data.groupby(dim, squeeze=False)] diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index e2a64964775..a948fafc815 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4614,10 +4614,12 @@ def test_to_and_from_dataframe(self) -> None: x = np.random.randn(10) y = np.random.randn(10) t = list("abcdefghij") - ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t)}) + cat = pd.Categorical(["a", "b"] * 5) + ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t), "cat": ("t", cat)}) expected = pd.DataFrame( np.array([x, y]).T, columns=["a", "b"], index=pd.Index(t, name="t") ) + expected["cat"] = cat actual = ds.to_dataframe() # use the .equals method to check all DataFrame metadata assert expected.equals(actual), (expected, actual) @@ -4628,23 +4630,31 @@ def test_to_and_from_dataframe(self) -> None: # check roundtrip assert_identical(ds, Dataset.from_dataframe(actual)) - + assert isinstance(ds["cat"].variable.data.dtype, pd.CategoricalDtype) # test a case with a MultiIndex w = np.random.randn(2, 3) - ds = Dataset({"w": (("x", "y"), w)}) + cat = pd.Categorical(["a", "a", "c"]) + ds = Dataset({"w": (("x", "y"), w), "cat": ("y", cat)}) ds["y"] = ("y", list("abc")) exp_index = pd.MultiIndex.from_arrays( [[0, 0, 0, 1, 1, 1], ["a", "b", "c", "a", "b", "c"]], names=["x", "y"] ) - expected = pd.DataFrame(w.reshape(-1), columns=["w"], index=exp_index) + expected = pd.DataFrame( + {"w": w.reshape(-1), "cat": pd.Categorical(["a", "a", "c", "a", "a", "c"])}, + index=exp_index, + ) actual = ds.to_dataframe() assert expected.equals(actual) # check roundtrip + # from_dataframe attempts to broadcast across all dimensions because it doesn't know better, so cat must be converted + ds["cat"] =
(("x", "y"), np.stack((ds["cat"].to_numpy(), ds["cat"].to_numpy()))) assert_identical(ds.assign_coords(x=[0, 1]), Dataset.from_dataframe(actual)) # Check multiindex reordering new_order = ["x", "y"] + # revert broadcasting fix above for 1d arrays + ds["cat"] = ("y", cat) actual = ds.to_dataframe(dim_order=new_order) assert expected.equals(actual) @@ -4653,7 +4663,11 @@ def test_to_and_from_dataframe(self) -> None: [["a", "a", "b", "b", "c", "c"], [0, 1, 0, 1, 0, 1]], names=["y", "x"] ) expected = pd.DataFrame( - w.transpose().reshape(-1), columns=["w"], index=exp_index + { + "w": w.transpose().reshape(-1), + "cat": pd.Categorical(["a", "a", "a", "a", "c", "c"]), + }, + index=exp_index, ) actual = ds.to_dataframe(dim_order=new_order) assert expected.equals(actual) @@ -4706,7 +4720,7 @@ def test_to_and_from_dataframe(self) -> None: expected = pd.DataFrame([[]], index=idx) assert expected.equals(actual), (expected, actual) - def test_from_dataframe_categorical(self) -> None: + def test_from_dataframe_categorical_index(self) -> None: cat = pd.CategoricalDtype( categories=["foo", "bar", "baz", "qux", "quux", "corge"] ) @@ -4721,7 +4735,7 @@ def test_from_dataframe_categorical(self) -> None: assert len(ds["i1"]) == 2 assert len(ds["i2"]) == 2 - def test_from_dataframe_categorical_string_categories(self) -> None: + def test_from_dataframe_categorical_index_string_categories(self) -> None: cat = pd.CategoricalIndex( pd.Categorical.from_codes( np.array([1, 1, 0, 2]), @@ -5449,18 +5463,22 @@ def test_reduce_cumsum_test_dims(self, reduct, expected, func) -> None: assert list(actual) == expected def test_reduce_non_numeric(self) -> None: - data1 = create_test_data(seed=44) + data1 = create_test_data(seed=44, use_extension_array=True) data2 = create_test_data(seed=44) - add_vars = {"var4": ["dim1", "dim2"], "var5": ["dim1"]} + add_vars = {"var5": ["dim1", "dim2"], "var6": ["dim1"]} for v, dims in sorted(add_vars.items()): size = tuple(data1.sizes[d] for d in dims) data = np.random.randint(0, 100, size=size).astype(np.str_) data1[v] = (dims, data, {"foo": "variable"}) - - assert "var4" not in data1.mean() and "var5" not in data1.mean() + # var4 is extension array categorical and should be dropped + assert ( + "var4" not in data1.mean() + and "var5" not in data1.mean() + and "var6" not in data1.mean() + ) assert_equal(data1.mean(), data2.mean()) assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1")) - assert "var4" not in data1.mean(dim="dim2") and "var5" in data1.mean(dim="dim2") + assert "var5" not in data1.mean(dim="dim2") and "var6" in data1.mean(dim="dim2") @pytest.mark.filterwarnings( "ignore:Once the behaviour of DataArray:DeprecationWarning" diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index df1ab1f40f9..26821c69495 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -27,6 +27,7 @@ timedelta_to_numeric, where, ) +from xarray.core.extension_array import PandasExtensionArray from xarray.namedarray.pycompat import array_type from xarray.testing import assert_allclose, assert_equal, assert_identical from xarray.tests import ( @@ -38,11 +39,55 @@ requires_bottleneck, requires_cftime, requires_dask, + requires_pyarrow, ) dask_array_type = array_type("dask") +@pytest.fixture +def categorical1(): + return pd.Categorical(["cat1", "cat2", "cat2", "cat1", "cat2"]) + + +@pytest.fixture +def categorical2(): + return pd.Categorical(["cat2", "cat1", "cat2", "cat3", "cat1"]) + + +try: + import pyarrow as pa + + 
@pytest.fixture + def arrow1(): + return pd.arrays.ArrowExtensionArray( + pa.array([{"x": 1, "y": True}, {"x": 2, "y": False}]) + ) + + @pytest.fixture + def arrow2(): + return pd.arrays.ArrowExtensionArray( + pa.array([{"x": 3, "y": False}, {"x": 4, "y": True}]) + ) + +except ImportError: + pass + + +@pytest.fixture +def int1(): + return pd.arrays.IntegerArray( + np.array([1, 2, 3, 4, 5]), np.array([True, False, False, True, True]) + ) + + +@pytest.fixture +def int2(): + return pd.arrays.IntegerArray( + np.array([6, 7, 8, 9, 10]), np.array([True, True, False, True, False]) + ) + + class TestOps: @pytest.fixture(autouse=True) def setUp(self): @@ -119,6 +164,51 @@ def test_where_type_promotion(self): assert result.dtype == np.float32 assert_array_equal(result, np.array([1, np.nan], dtype=np.float32)) + def test_where_extension_duck_array(self, categorical1, categorical2): + where_res = where( + np.array([True, False, True, False, False]), + PandasExtensionArray(categorical1), + PandasExtensionArray(categorical2), + ) + assert isinstance(where_res, PandasExtensionArray) + assert ( + where_res == pd.Categorical(["cat1", "cat1", "cat2", "cat3", "cat1"]) + ).all() + + def test_concatenate_extension_duck_array(self, categorical1, categorical2): + concate_res = concatenate( + [PandasExtensionArray(categorical1), PandasExtensionArray(categorical2)] + ) + assert isinstance(concate_res, PandasExtensionArray) + assert ( + concate_res + == type(categorical1)._concat_same_type((categorical1, categorical2)) + ).all() + + @requires_pyarrow + def test_duck_extension_array_pyarrow_concatenate(self, arrow1, arrow2): + concatenated = concatenate( + (PandasExtensionArray(arrow1), PandasExtensionArray(arrow2)) + ) + assert concatenated[2]["x"] == 3 + assert concatenated[3]["y"] + + def test___getitem__extension_duck_array(self, categorical1): + extension_duck_array = PandasExtensionArray(categorical1) + assert (extension_duck_array[0:2] == categorical1[0:2]).all() + assert isinstance(extension_duck_array[0:2], PandasExtensionArray) + assert extension_duck_array[0] == categorical1[0] + assert isinstance(extension_duck_array[0], PandasExtensionArray) + mask = [True, False, True, False, True] + assert (extension_duck_array[mask] == categorical1[mask]).all() + + def test__setitem__extension_duck_array(self, categorical1): + extension_duck_array = PandasExtensionArray(categorical1) + extension_duck_array[2] = "cat1" # already existing category + assert extension_duck_array[2] == "cat1" + with pytest.raises(TypeError, match="Cannot setitem on a Categorical"): + extension_duck_array[2] = "cat4" # new category + def test_stack_type_promotion(self): result = stack([1, "b"]) assert_array_equal(result, np.array([1, "b"], dtype=object)) @@ -932,3 +1022,21 @@ def test_push_dask(): dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=n ) np.testing.assert_equal(actual, expected) + + +def test_duck_extension_array_equality(categorical1, int1): + int_duck_array = PandasExtensionArray(int1) + categorical_duck_array = PandasExtensionArray(categorical1) + assert (int_duck_array != categorical_duck_array).all() + assert (categorical_duck_array == categorical1).all() + assert (int_duck_array[0:2] == int1[0:2]).all() + + +def test_duck_extension_array_repr(int1): + int_duck_array = PandasExtensionArray(int1) + assert repr(int1) in repr(int_duck_array) + + +def test_duck_extension_array_attr(int1): + int_duck_array = PandasExtensionArray(int1) + assert (~int_duck_array.fillna(10)).all() diff --git 
a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index df9c40ca6f4..afe4d669628 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -35,6 +35,7 @@ def dataset() -> xr.Dataset: { "foo": (("x", "y", "z"), np.random.randn(3, 4, 2)), "baz": ("x", ["e", "f", "g"]), + "cat": ("y", pd.Categorical(["cat1", "cat2", "cat2", "cat1"])), }, {"x": ("x", ["a", "b", "c"], {"name": "x"}), "y": [1, 2, 3, 4], "z": [1, 2]}, ) @@ -79,6 +80,7 @@ def test_groupby_dims_property(dataset, recwarn) -> None: ) assert len(recwarn) == 0 + dataset = dataset.drop_vars(["cat"]) stacked = dataset.stack({"xy": ("x", "y")}) assert tuple(stacked.groupby("xy", squeeze=False).dims) == tuple( stacked.isel(xy=[0]).dims @@ -91,7 +93,7 @@ def test_groupby_sizes_property(dataset) -> None: assert dataset.groupby("x").sizes == dataset.isel(x=1).sizes with pytest.warns(UserWarning, match="The `squeeze` kwarg"): assert dataset.groupby("y").sizes == dataset.isel(y=1).sizes - + dataset = dataset.drop("cat") stacked = dataset.stack({"xy": ("x", "y")}) with pytest.warns(UserWarning, match="The `squeeze` kwarg"): assert stacked.groupby("xy").sizes == stacked.isel(xy=0).sizes @@ -760,6 +762,8 @@ def test_groupby_getitem(dataset) -> None: assert_identical(dataset.foo.sel(x="a"), dataset.foo.groupby("x")["a"]) with pytest.warns(UserWarning, match="The `squeeze` kwarg"): assert_identical(dataset.foo.sel(z=1), dataset.foo.groupby("z")[1]) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(dataset.cat.sel(y=1), dataset.cat.groupby("y")[1]) assert_identical(dataset.sel(x=["a"]), dataset.groupby("x", squeeze=False)["a"]) assert_identical(dataset.sel(z=[1]), dataset.groupby("z", squeeze=False)[1]) @@ -769,6 +773,12 @@ def test_groupby_getitem(dataset) -> None: ) assert_identical(dataset.foo.sel(z=[1]), dataset.foo.groupby("z", squeeze=False)[1]) + assert_identical(dataset.cat.sel(y=[1]), dataset.cat.groupby("y", squeeze=False)[1]) + with pytest.raises( + NotImplementedError, match="Cannot broadcast 1d-only pandas categorical array." 
+ ): + dataset.groupby("boo", squeeze=False) + dataset = dataset.drop_vars(["cat"]) actual = ( dataset.groupby("boo", squeeze=False)["f"].unstack().transpose("x", "y", "z") ) diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index c6597d5abb0..52935e9714e 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -37,7 +37,7 @@ def test_merge_arrays(self): assert_identical(actual, expected) def test_merge_datasets(self): - data = create_test_data(add_attrs=False) + data = create_test_data(add_attrs=False, use_extension_array=True) actual = xr.merge([data[["var1"]], data[["var2"]]]) expected = data[["var1", "var2"]] diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index d9289aa6674..8a9345e74d4 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1576,6 +1576,20 @@ def test_transpose_0d(self): actual = variable.transpose() assert_identical(actual, variable) + def test_pandas_categorical_dtype(self): + data = pd.Categorical(np.arange(10, dtype="int64")) + v = self.cls("x", data) + print(v) # should not error + assert pd.api.types.is_extension_array_dtype(v.dtype) + + def test_pandas_categorical_no_chunk(self): + data = pd.Categorical(np.arange(10, dtype="int64")) + v = self.cls("x", data) + with pytest.raises( + ValueError, match=r".*was found to be a Pandas ExtensionArray.*" + ): + v.chunk((5,)) + def test_squeeze(self): v = Variable(["x", "y"], [[1]]) assert_identical(Variable([], 1), v.squeeze()) @@ -2373,6 +2387,11 @@ def test_multiindex(self): def test_pad(self, mode, xr_arg, np_arg): super().test_pad(mode, xr_arg, np_arg) + def test_pandas_categorical_dtype(self): + data = pd.Categorical(np.arange(10, dtype="int64")) + with pytest.raises(ValueError, match="was found to be a Pandas ExtensionArray"): + self.cls("x", data) + @requires_sparse class TestVariableWithSparse:
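Editorial note: a minimal sketch of how the new ``PandasExtensionArray`` wrapper dispatches NumPy-level operations, mirroring the tests added above (this is internal xarray API, so the import paths may change; the category values are made up)::

    import numpy as np
    import pandas as pd

    from xarray.core.duck_array_ops import concatenate, where
    from xarray.core.extension_array import PandasExtensionArray

    cat1 = pd.Categorical(["cat1", "cat2", "cat2", "cat1", "cat2"])
    cat2 = pd.Categorical(["cat2", "cat1", "cat2", "cat3", "cat1"])

    # np.concatenate is intercepted via __array_function__ and routed to
    # Categorical._concat_same_type instead of coercing to numpy
    joined = concatenate([PandasExtensionArray(cat1), PandasExtensionArray(cat2)])
    assert isinstance(joined, PandasExtensionArray)

    # np.where unifies the categories of both sides before delegating to pandas
    chosen = where(
        np.array([True, False, True, False, False]),
        PandasExtensionArray(cat1),
        PandasExtensionArray(cat2),
    )
    assert (chosen == pd.Categorical(["cat1", "cat1", "cat2", "cat3", "cat1"])).all()

Conversely, operations a 1D extension array cannot support now fail loudly: per the guard added in ``variable.py``, calling ``.chunk`` on a categorical-backed ``Variable`` raises a ``ValueError`` asking for conversion to numpy first.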
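The ``implements`` registry in ``extension_array.py`` is also the extension point for covering further NumPy functions. A hypothetical registration, not part of this patch and shown only to illustrate the mechanism, might look like::

    import numpy as np
    import pandas as pd

    from xarray.core.extension_array import implements

    # hypothetical handler: route np.unique on wrapped extension arrays to
    # pandas, which computes uniques without coercing to a numpy object array
    # (covers only the default np.unique(ar) call, no extra options)
    @implements(np.unique)
    def __extension_duck_array__unique(ar):
        return pd.unique(ar)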