Skip to content

Commit

Permalink
Drop multi-indexes when assigning to a multi-indexed variable (#6798)
Browse files Browse the repository at this point in the history
Co-authored-by: Anderson Banihirwe <[email protected]>
Co-authored-by: Benoit Bovy <[email protected]>
  • Loading branch information
3 people authored Jul 21, 2022
1 parent 9f8d47c commit 4a52799
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 1 deletion.
55 changes: 54 additions & 1 deletion xarray/core/coordinates.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import warnings
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Hashable, Iterator, Mapping, Sequence, cast

import numpy as np
import pandas as pd

from . import formatting
from .indexes import Index, Indexes, assert_no_index_corrupted
from .indexes import Index, Indexes, PandasMultiIndex, assert_no_index_corrupted
from .merge import merge_coordinates_without_align, merge_coords
from .utils import Frozen, ReprObject
from .variable import Variable, calculate_dimensions
Expand Down Expand Up @@ -57,6 +58,9 @@ def variables(self):
def _update_coords(self, coords, indexes):
raise NotImplementedError()

def _maybe_drop_multiindex_coords(self, coords):
raise NotImplementedError()

def __iter__(self) -> Iterator[Hashable]:
# needs to be in the same order as the dataset variables
for k in self.variables:
Expand Down Expand Up @@ -154,6 +158,7 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:

def update(self, other: Mapping[Any, Any]) -> None:
other_vars = getattr(other, "variables", other)
self._maybe_drop_multiindex_coords(set(other_vars))
coords, indexes = merge_coords(
[self.variables, other_vars], priority_arg=1, indexes=self.xindexes
)
Expand Down Expand Up @@ -304,6 +309,15 @@ def _update_coords(
original_indexes.update(indexes)
self._data._indexes = original_indexes

def _maybe_drop_multiindex_coords(self, coords: set[Hashable]) -> None:
"""Drops variables in coords, and any associated variables as well."""
assert self._data.xindexes is not None
variables, indexes = drop_coords(
coords, self._data._variables, self._data.xindexes
)
self._data._variables = variables
self._data._indexes = indexes

def __delitem__(self, key: Hashable) -> None:
if key in self:
del self._data[key]
Expand Down Expand Up @@ -372,6 +386,14 @@ def _update_coords(
original_indexes.update(indexes)
self._data._indexes = original_indexes

def _maybe_drop_multiindex_coords(self, coords: set[Hashable]) -> None:
"""Drops variables in coords, and any associated variables as well."""
variables, indexes = drop_coords(
coords, self._data._coords, self._data.xindexes
)
self._data._coords = variables
self._data._indexes = indexes

@property
def variables(self):
return Frozen(self._data._coords)
Expand All @@ -397,6 +419,37 @@ def _ipython_key_completions_(self):
return self._data._ipython_key_completions_()


def drop_coords(
coords_to_drop: set[Hashable], variables, indexes: Indexes
) -> tuple[dict, dict]:
"""Drop index variables associated with variables in coords_to_drop."""
# Only warn when we're dropping the dimension with the multi-indexed coordinate
# If asked to drop a subset of the levels in a multi-index, we raise an error
# later but skip the warning here.
new_variables = dict(variables.copy())
new_indexes = dict(indexes.copy())
for key in coords_to_drop & set(indexes):
maybe_midx = indexes[key]
idx_coord_names = set(indexes.get_all_coords(key))
if (
isinstance(maybe_midx, PandasMultiIndex)
and key == maybe_midx.dim
and (idx_coord_names - coords_to_drop)
):
warnings.warn(
f"Updating MultiIndexed coordinate {key!r} would corrupt indices for "
f"other variables: {list(maybe_midx.index.names)!r}. "
f"This will raise an error in the future. Use `.drop_vars({idx_coord_names!r})` before "
"assigning new coordinate values.",
DeprecationWarning,
stacklevel=4,
)
for k in idx_coord_names:
del new_variables[k]
del new_indexes[k]
return new_variables, new_indexes


def assert_coordinate_consistent(
obj: DataArray | Dataset, coords: Mapping[Any, Variable]
) -> None:
Expand Down
1 change: 1 addition & 0 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5764,6 +5764,7 @@ def assign(
data = self.copy()
# do all calculations first...
results: CoercibleMapping = data._calc_assign_results(variables)
data.coords._maybe_drop_multiindex_coords(set(results.keys()))
# ... and then assign
data.update(results)
return data
Expand Down
3 changes: 3 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,9 @@ def dims(self) -> Mapping[Hashable, int]:

return Frozen(self._dims)

def copy(self):
return type(self)(dict(self._indexes), dict(self._variables))

def get_unique(self) -> list[T_PandasOrXarrayIndex]:
"""Return a list of unique indexes, preserving order."""

Expand Down
7 changes: 7 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,13 @@ def test_assign_coords(self) -> None:
with pytest.raises(ValueError):
da.coords["x"] = ("y", [1, 2, 3]) # no new dimension to a DataArray

def test_assign_coords_existing_multiindex(self) -> None:
data = self.mda
with pytest.warns(
DeprecationWarning, match=r"Updating MultiIndexed coordinate"
):
data.assign_coords(x=range(4))

def test_coords_alignment(self) -> None:
lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])])
rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])])
Expand Down
12 changes: 12 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3967,6 +3967,18 @@ def test_assign_multiindex_level(self) -> None:
data.assign(level_1=range(4))
data.assign_coords(level_1=range(4))

def test_assign_coords_existing_multiindex(self) -> None:
data = create_test_multiindex()
with pytest.warns(
DeprecationWarning, match=r"Updating MultiIndexed coordinate"
):
data.assign_coords(x=range(4))

with pytest.warns(
DeprecationWarning, match=r"Updating MultiIndexed coordinate"
):
data.assign(x=range(4))

def test_assign_all_multiindex_coords(self) -> None:
data = create_test_multiindex()
actual = data.assign(x=range(4), level_1=range(4), level_2=range(4))
Expand Down

0 comments on commit 4a52799

Please sign in to comment.