From 5ff78b859a263ff2ab85d0f0ede3abf15b91b543 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Thu, 23 Nov 2023 22:30:07 +0100 Subject: [PATCH 1/3] Fix bug for categorical pandas index with categories with EA dtype --- xarray/core/utils.py | 2 ++ xarray/tests/test_dataset.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index ad86b2c7fec..9ba4a43f6d9 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -114,6 +114,8 @@ def get_valid_numpy_dtype(array: np.ndarray | pd.Index): elif hasattr(array, "categories"): # category isn't a real numpy dtype dtype = array.categories.dtype + if not is_valid_numpy_dtype(dtype): + dtype = np.dtype("O") elif not is_valid_numpy_dtype(array.dtype): dtype = np.dtype("O") else: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index ff7703a1cf5..6a1261fcd32 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4697,6 +4697,17 @@ def test_from_dataframe_categorical(self) -> None: assert len(ds["i1"]) == 2 assert len(ds["i2"]) == 2 + def test_from_dataframe_categorical_string_categories(self): + cat = pd.CategoricalIndex( + pd.Categorical.from_codes( + np.array([1, 1, 0, 2]), + categories=pd.Index(["foo", "bar", "baz"], dtype="string"), + ) + ) + ser = pd.Series(1, index=cat) + ds = ser.to_xarray() + assert ds.coords.dtypes["index"] == np.dtype("O") + @requires_sparse def test_from_dataframe_sparse(self) -> None: import sparse From e4452c115bbd19dbb0602bb1ca85a7c5e334cdb7 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Thu, 23 Nov 2023 22:33:14 +0100 Subject: [PATCH 2/3] Add whatsnew --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 76548fe95c5..3a947daf153 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -42,6 +42,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Fix dtype inference for ``pd.CategoricalIndex`` when categories are backed by a ``pd.ExtensionDtype`` (:pull:`8481`) + Documentation ~~~~~~~~~~~~~ From ae99fdd64fc1b1b856d1ff51d3f2fef73cc28039 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 24 Nov 2023 15:06:23 +0100 Subject: [PATCH 3/3] Update xarray/tests/test_dataset.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/tests/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 6a1261fcd32..a53d81e36af 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4697,7 +4697,7 @@ def test_from_dataframe_categorical(self) -> None: assert len(ds["i1"]) == 2 assert len(ds["i2"]) == 2 - def test_from_dataframe_categorical_string_categories(self): + def test_from_dataframe_categorical_string_categories(self) -> None: cat = pd.CategoricalIndex( pd.Categorical.from_codes( np.array([1, 1, 0, 2]),