diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b8fb1f8f58e..abd94779435 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -79,6 +79,8 @@ New Features Bug fixes ~~~~~~~~~ +- Fix a bug in `set_index` in case that an existing dimension becomes a level variable of MultiIndex. (:pull:`3520`) + By `Keisuke Fujii `_. - Harmonize `_FillValue`, `missing_value` during encoding and decoding steps. (:pull:`3502`) By `Anderson Banihirwe `_. - Fix regression introduced in v0.14.0 that would cause a crash if dask is installed diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index a192fe08cee..55e73478260 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -48,7 +48,7 @@ assert_coordinate_consistent, remap_label_indexers, ) -from .dataset import Dataset, merge_indexes, split_indexes +from .dataset import Dataset, split_indexes from .formatting import format_item from .indexes import Indexes, default_indexes from .merge import PANDAS_TYPES @@ -1601,10 +1601,10 @@ def set_index( -------- DataArray.reset_index """ - _check_inplace(inplace) - indexes = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index") - coords, _ = merge_indexes(indexes, self._coords, set(), append=append) - return self._replace(coords=coords) + ds = self._to_temp_dataset().set_index( + indexes, append=append, inplace=inplace, **indexes_kwargs + ) + return self._from_temp_dataset(ds) def reset_index( self, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 15a7209ab24..de713b830f2 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -204,6 +204,7 @@ def merge_indexes( """ vars_to_replace: Dict[Hashable, Variable] = {} vars_to_remove: List[Hashable] = [] + dims_to_replace: Dict[Hashable, Hashable] = {} error_msg = "{} is not the name of an existing variable." for dim, var_names in indexes.items(): @@ -244,7 +245,7 @@ def merge_indexes( if not len(names) and len(var_names) == 1: idx = pd.Index(variables[var_names[0]].values) - else: + else: # MultiIndex for n in var_names: try: var = variables[n] @@ -256,15 +257,22 @@ def merge_indexes( levels.append(cat.categories) idx = pd.MultiIndex(levels, codes, names=names) + for n in names: + dims_to_replace[n] = dim vars_to_replace[dim] = IndexVariable(dim, idx) vars_to_remove.extend(var_names) new_variables = {k: v for k, v in variables.items() if k not in vars_to_remove} new_variables.update(vars_to_replace) + + # update dimensions if necessary GH: 3512 + for k, v in new_variables.items(): + if any(d in dims_to_replace for d in v.dims): + new_dims = [dims_to_replace.get(d, d) for d in v.dims] + new_variables[k] = v._replace(dims=new_dims) new_coord_names = coord_names | set(vars_to_replace) new_coord_names -= set(vars_to_remove) - return new_variables, new_coord_names diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 7c6dc1825a1..4c3553c867e 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1182,6 +1182,16 @@ def test_selection_multiindex_remove_unused(self): expected = expected.set_index(xy=["x", "y"]).unstack() assert_identical(expected, actual) + def test_selection_multiindex_from_level(self): + # GH: 3512 + da = DataArray([0, 1], dims=["x"], coords={"x": [0, 1], "y": "a"}) + db = DataArray([2, 3], dims=["x"], coords={"x": [0, 1], "y": "b"}) + data = xr.concat([da, db], dim="x").set_index(xy=["x", "y"]) + assert data.dims == ("xy",) + actual = data.sel(y="a") + expected = data.isel(xy=[0, 1]).unstack("xy").squeeze("y").drop("y") + assert_equal(actual, expected) + def test_virtual_default_coords(self): array = DataArray(np.zeros((5,)), dims="x") expected = DataArray(range(5), dims="x", name="x")