Skip to content

Commit

Permalink
sel with categorical index (#3670)
Browse files Browse the repository at this point in the history
* Added support for a categorical index

* fix from_dataframe

* update a test

* Added more tests

* black

* Added a test to make sure ValueErrors are raised

* remove unnecessary print

* added a test for reindex

* Fix according to reviews

* blacken

* delete trailing whitespace

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: keewis <[email protected]>
3 people authored Jan 25, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 9c72866 commit cc142f4
Showing 5 changed files with 110 additions and 9 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
@@ -40,6 +40,8 @@ Breaking changes

New Features
~~~~~~~~~~~~
- :py:meth:`DataArray.sel` and :py:meth:`Dataset.sel` now support :py:class:`pandas.CategoricalIndex`. (:issue:`3669`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
- Support using an existing, opened h5netcdf ``File`` with
:py:class:`~xarray.backends.H5NetCDFStore`. This permits creating an
:py:class:`~xarray.Dataset` from a h5netcdf ``File`` that has been opened
22 changes: 13 additions & 9 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
@@ -64,6 +64,7 @@
default_indexes,
isel_variable_and_index,
propagate_indexes,
remove_unused_levels_categories,
roll_index,
)
from .indexing import is_fancy_indexer
@@ -3411,7 +3412,7 @@ def ensure_stackable(val):

def _unstack_once(self, dim: Hashable, fill_value, sparse) -> "Dataset":
index = self.get_index(dim)
index = index.remove_unused_levels()
index = remove_unused_levels_categories(index)
full_idx = pd.MultiIndex.from_product(index.levels, names=index.names)

# take a shortcut in case the MultiIndex was not modified.
@@ -4460,17 +4461,19 @@ def to_dataframe(self):
return self._to_dataframe(self.dims)

def _set_sparse_data_from_dataframe(
self, dataframe: pd.DataFrame, dims: tuple, shape: Tuple[int, ...]
self, dataframe: pd.DataFrame, dims: tuple
) -> None:
from sparse import COO

idx = dataframe.index
if isinstance(idx, pd.MultiIndex):
coords = np.stack([np.asarray(code) for code in idx.codes], axis=0)
is_sorted = idx.is_lexsorted
shape = tuple(lev.size for lev in idx.levels)
else:
coords = np.arange(idx.size).reshape(1, -1)
is_sorted = True
shape = (idx.size,)

for name, series in dataframe.items():
# Cast to a NumPy array first, in case the Series is a pandas
@@ -4495,14 +4498,16 @@ def _set_sparse_data_from_dataframe(
self[name] = (dims, data)

def _set_numpy_data_from_dataframe(
self, dataframe: pd.DataFrame, dims: tuple, shape: Tuple[int, ...]
self, dataframe: pd.DataFrame, dims: tuple
) -> None:
idx = dataframe.index
if isinstance(idx, pd.MultiIndex):
# expand the DataFrame to include the product of all levels
full_idx = pd.MultiIndex.from_product(idx.levels, names=idx.names)
dataframe = dataframe.reindex(full_idx)

shape = tuple(lev.size for lev in idx.levels)
else:
shape = (idx.size,)
for name, series in dataframe.items():
data = np.asarray(series).reshape(shape)
self[name] = (dims, data)
@@ -4543,7 +4548,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas
if not dataframe.columns.is_unique:
raise ValueError("cannot convert DataFrame with non-unique columns")

idx = dataframe.index
idx = remove_unused_levels_categories(dataframe.index)
dataframe = dataframe.set_index(idx)
obj = cls()

if isinstance(idx, pd.MultiIndex):
@@ -4553,17 +4559,15 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas
)
for dim, lev in zip(dims, idx.levels):
obj[dim] = (dim, lev)
shape = tuple(lev.size for lev in idx.levels)
else:
index_name = idx.name if idx.name is not None else "index"
dims = (index_name,)
obj[index_name] = (dims, idx)
shape = (idx.size,)

if sparse:
obj._set_sparse_data_from_dataframe(dataframe, dims, shape)
obj._set_sparse_data_from_dataframe(dataframe, dims)
else:
obj._set_numpy_data_from_dataframe(dataframe, dims, shape)
obj._set_numpy_data_from_dataframe(dataframe, dims)
return obj

def to_dask_dataframe(self, dim_order=None, set_index=False):
20 changes: 20 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,26 @@
from .variable import Variable


def remove_unused_levels_categories(index):
    """
    Remove unused levels from MultiIndex and unused categories from CategoricalIndex

    Parameters
    ----------
    index : pd.Index
        Any pandas index. Only MultiIndex and CategoricalIndex are modified;
        every other index type is returned unchanged.

    Returns
    -------
    pd.Index
        The index with unused MultiIndex levels and/or unused categories
        dropped.
    """
    if isinstance(index, pd.MultiIndex):
        index = index.remove_unused_levels()
        # if it contains CategoricalIndex, we need to remove unused categories
        # manually. See https://github.com/pandas-dev/pandas/issues/30846
        if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels):
            levels = []
            for i, level in enumerate(index.levels):
                if isinstance(level, pd.CategoricalIndex):
                    level = level[index.codes[i]].remove_unused_categories()
                else:
                    # expand non-categorical levels to full length as well,
                    # so that all arrays passed to from_arrays below have the
                    # same length (otherwise from_arrays raises ValueError for
                    # a MultiIndex mixing categorical and ordinary levels)
                    level = level[index.codes[i]]
                levels.append(level)
            index = pd.MultiIndex.from_arrays(levels, names=index.names)
    elif isinstance(index, pd.CategoricalIndex):
        index = index.remove_unused_categories()
    return index


class Indexes(collections.abc.Mapping):
"""Immutable proxy for Dataset or DataArray indexes."""

10 changes: 10 additions & 0 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
@@ -175,6 +175,16 @@ def convert_label_indexer(index, label, index_name="", method=None, tolerance=No
if label.ndim == 0:
if isinstance(index, pd.MultiIndex):
indexer, new_index = index.get_loc_level(label.item(), level=0)
elif isinstance(index, pd.CategoricalIndex):
if method is not None:
raise ValueError(
"'method' is not a valid kwarg when indexing using a CategoricalIndex."
)
if tolerance is not None:
raise ValueError(
"'tolerance' is not a valid kwarg when indexing using a CategoricalIndex."
)
indexer = index.get_loc(label.item())
else:
indexer = index.get_loc(
label.item(), method=method, tolerance=tolerance
65 changes: 65 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1408,6 +1408,56 @@ def test_sel_dataarray_mindex(self):
)
)

def test_sel_categorical(self):
    # Selecting a single label on a categorical dimension should match
    # positional selection of the corresponding element.
    series = pd.Series(["foo", "bar"], dtype="category")
    frame = pd.DataFrame({"ind": series, "values": [1, 2]})
    dataset = frame.set_index("ind").to_xarray()
    by_label = dataset.sel(ind="bar")
    by_position = dataset.isel(ind=1)
    assert_identical(by_position, by_label)

def test_sel_categorical_error(self):
    # method/tolerance kwargs are rejected when indexing a CategoricalIndex
    series = pd.Series(["foo", "bar"], dtype="category")
    frame = pd.DataFrame({"ind": series, "values": [1, 2]})
    dataset = frame.set_index("ind").to_xarray()
    with pytest.raises(ValueError):
        dataset.sel(ind="bar", method="nearest")
    with pytest.raises(ValueError):
        dataset.sel(ind="bar", tolerance="nearest")

def test_categorical_index(self):
    cat_idx = pd.CategoricalIndex(
        ["foo", "bar", "foo"],
        categories=["foo", "bar", "baz", "qux", "quux", "corge"],
    )
    dataset = xr.Dataset(
        {"var": ("cat", np.arange(3))},
        coords={"cat": ("cat", cat_idx), "c": ("cat", [0, 1, 1])},
    )
    # label selection on the categorical dimension behaves like a slice
    result = dataset.sel(cat="foo")
    assert_identical(dataset.isel(cat=[0, 2]), result)
    # the selected categorical coordinate converts cleanly to an array
    selected = dataset.sel(cat="foo")["cat"].values
    assert (selected == np.array(["foo", "foo"])).all()

    # unstacking a MultiIndex containing a categorical level
    stacked = dataset.set_index(index=["cat", "c"])
    unstacked = stacked.unstack("index")
    assert unstacked["var"].shape == (2, 2)

def test_categorical_reindex(self):
    # reindexing with a subset of labels keeps the categorical coordinate
    cat_idx = pd.CategoricalIndex(
        ["foo", "bar", "baz"],
        categories=["foo", "bar", "baz", "qux", "quux", "corge"],
    )
    dataset = xr.Dataset(
        {"var": ("cat", np.arange(3))},
        coords={"cat": ("cat", cat_idx), "c": ("cat", [0, 1, 2])},
    )
    reindexed = dataset.reindex(cat=["foo"])["cat"].values
    assert (reindexed == np.array(["foo"])).all()

def test_sel_drop(self):
data = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]})
expected = Dataset({"foo": 1})
@@ -3865,6 +3915,21 @@ def test_to_and_from_dataframe(self):
expected = pd.DataFrame([[]], index=idx)
assert expected.equals(actual), (expected, actual)

def test_from_dataframe_categorical(self):
    dtype = pd.CategoricalDtype(
        categories=["foo", "bar", "baz", "qux", "quux", "corge"]
    )
    col1 = pd.Series(["foo", "bar", "foo"], dtype=dtype)
    col2 = pd.Series(["bar", "bar", "baz"], dtype=dtype)

    frame = pd.DataFrame({"i1": col1, "i2": col2, "values": [1, 2, 3]})

    # single categorical index: every row is kept
    single = frame.set_index("i1").to_xarray()
    assert len(single["i1"]) == 3

    # categorical MultiIndex: unused categories are dropped per level
    multi = frame.set_index(["i1", "i2"]).to_xarray()
    assert len(multi["i1"]) == 2
    assert len(multi["i2"]) == 2

@requires_sparse
def test_from_dataframe_sparse(self):
import sparse

0 comments on commit cc142f4

Please sign in to comment.