Skip to content
This repository has been archived by the owner on Oct 7, 2024. It is now read-only.

Commit

Permalink
Fix set_index when an existing dimension becomes a level (pydata#3520)
Browse files Browse the repository at this point in the history
* Added a test

* Fix set_index

* lint

* black / mypy

* Use _replace method

* whats new
  • Loading branch information
fujiisoup authored Nov 14, 2019
1 parent 8b24037 commit c0ef2f6
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 7 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ New Features

Bug fixes
~~~~~~~~~
- Fix a bug in `set_index` in case that an existing dimension becomes a level variable of MultiIndex. (:pull:`3520`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
- Harmonize `_FillValue`, `missing_value` during encoding and decoding steps. (:pull:`3502`)
By `Anderson Banihirwe <https://github.com/andersy005>`_.
- Fix regression introduced in v0.14.0 that would cause a crash if dask is installed
Expand Down
10 changes: 5 additions & 5 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
assert_coordinate_consistent,
remap_label_indexers,
)
from .dataset import Dataset, merge_indexes, split_indexes
from .dataset import Dataset, split_indexes
from .formatting import format_item
from .indexes import Indexes, default_indexes
from .merge import PANDAS_TYPES
Expand Down Expand Up @@ -1601,10 +1601,10 @@ def set_index(
--------
DataArray.reset_index
"""
_check_inplace(inplace)
indexes = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index")
coords, _ = merge_indexes(indexes, self._coords, set(), append=append)
return self._replace(coords=coords)
ds = self._to_temp_dataset().set_index(
indexes, append=append, inplace=inplace, **indexes_kwargs
)
return self._from_temp_dataset(ds)

def reset_index(
self,
Expand Down
12 changes: 10 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def merge_indexes(
"""
vars_to_replace: Dict[Hashable, Variable] = {}
vars_to_remove: List[Hashable] = []
dims_to_replace: Dict[Hashable, Hashable] = {}
error_msg = "{} is not the name of an existing variable."

for dim, var_names in indexes.items():
Expand Down Expand Up @@ -244,7 +245,7 @@ def merge_indexes(
if not len(names) and len(var_names) == 1:
idx = pd.Index(variables[var_names[0]].values)

else:
else: # MultiIndex
for n in var_names:
try:
var = variables[n]
Expand All @@ -256,15 +257,22 @@ def merge_indexes(
levels.append(cat.categories)

idx = pd.MultiIndex(levels, codes, names=names)
for n in names:
dims_to_replace[n] = dim

vars_to_replace[dim] = IndexVariable(dim, idx)
vars_to_remove.extend(var_names)

new_variables = {k: v for k, v in variables.items() if k not in vars_to_remove}
new_variables.update(vars_to_replace)

# update dimensions if necessary GH: 3512
for k, v in new_variables.items():
if any(d in dims_to_replace for d in v.dims):
new_dims = [dims_to_replace.get(d, d) for d in v.dims]
new_variables[k] = v._replace(dims=new_dims)
new_coord_names = coord_names | set(vars_to_replace)
new_coord_names -= set(vars_to_remove)

return new_variables, new_coord_names


Expand Down
10 changes: 10 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,16 @@ def test_selection_multiindex_remove_unused(self):
expected = expected.set_index(xy=["x", "y"]).unstack()
assert_identical(expected, actual)

def test_selection_multiindex_from_level(self):
# GH: 3512
da = DataArray([0, 1], dims=["x"], coords={"x": [0, 1], "y": "a"})
db = DataArray([2, 3], dims=["x"], coords={"x": [0, 1], "y": "b"})
data = xr.concat([da, db], dim="x").set_index(xy=["x", "y"])
assert data.dims == ("xy",)
actual = data.sel(y="a")
expected = data.isel(xy=[0, 1]).unstack("xy").squeeze("y").drop("y")
assert_equal(actual, expected)

def test_virtual_default_coords(self):
array = DataArray(np.zeros((5,)), dims="x")
expected = DataArray(range(5), dims="x", name="x")
Expand Down

0 comments on commit c0ef2f6

Please sign in to comment.