diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a8302715317..d92f3239f60 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -42,6 +42,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Fix dtype inference for ``pd.CategoricalIndex`` when categories are backed by a ``pd.ExtensionDtype`` (:pull:`8481`) + Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/utils.py b/xarray/core/utils.py index ad86b2c7fec..9ba4a43f6d9 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -114,6 +114,8 @@ def get_valid_numpy_dtype(array: np.ndarray | pd.Index): elif hasattr(array, "categories"): # category isn't a real numpy dtype dtype = array.categories.dtype + if not is_valid_numpy_dtype(dtype): + dtype = np.dtype("O") elif not is_valid_numpy_dtype(array.dtype): dtype = np.dtype("O") else: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index ff7703a1cf5..a53d81e36af 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4697,6 +4697,17 @@ def test_from_dataframe_categorical(self) -> None: assert len(ds["i1"]) == 2 assert len(ds["i2"]) == 2 + def test_from_dataframe_categorical_string_categories(self) -> None: + cat = pd.CategoricalIndex( + pd.Categorical.from_codes( + np.array([1, 1, 0, 2]), + categories=pd.Index(["foo", "bar", "baz"], dtype="string"), + ) + ) + ser = pd.Series(1, index=cat) + ds = ser.to_xarray() + assert ds.coords.dtypes["index"] == np.dtype("O") + @requires_sparse def test_from_dataframe_sparse(self) -> None: import sparse