Skip to content

Commit

Permalink
Fix assign_coords resetting all dimension coords to default index (#7347
Browse files Browse the repository at this point in the history
)

* fix merge_coords and collect_variables_and_indexes

* add test

* update what's new

* update what's new with pull-request link
  • Loading branch information
benbovy authored Dec 2, 2022
1 parent 92e7cb5 commit 8938d39
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Bug fixes
By `Michael Niklas <https://github.com/headtr1ck>`_.
- Fix multiple reads on fsspec S3 files by resetting file pointer to 0 when reading file streams (:issue:`6813`, :pull:`7304`).
By `David Hoese <https://github.com/djhoese>`_ and `Wei Ji Leong <https://github.com/weiji14>`_.
- Fix :py:meth:`Dataset.assign_coords` resetting all dimension coordinates to default (pandas) index (:issue:`7346`, :pull:`7347`).
By `Benoît Bovy <https://github.com/benbovy>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
12 changes: 6 additions & 6 deletions xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,12 @@ def append_all(variables, indexes):

for name, variable in mapping.items():
if isinstance(variable, DataArray):
coords = variable._coords.copy() # use private API for speed
indexes = dict(variable._indexes)
coords_ = variable._coords.copy() # use private API for speed
indexes_ = dict(variable._indexes)
# explicitly overwritten variables should take precedence
coords.pop(name, None)
indexes.pop(name, None)
append_all(coords, indexes)
coords_.pop(name, None)
indexes_.pop(name, None)
append_all(coords_, indexes_)

variable = as_variable(variable, name=name)
if name in indexes:
Expand Down Expand Up @@ -561,7 +561,7 @@ def merge_coords(
aligned = deep_align(
coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value
)
collected = collect_variables_and_indexes(aligned)
collected = collect_variables_and_indexes(aligned, indexes=indexes)
prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat)
variables, out_indexes = merge_collected(collected, prioritized, compat=compat)
return variables, out_indexes
Expand Down
14 changes: 14 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4170,6 +4170,20 @@ def test_assign_all_multiindex_coords(self) -> None:
is not actual.xindexes["level_2"]
)

def test_assign_coords_custom_index_side_effect(self) -> None:
# test that assigning new coordinates do not reset other dimension coord indexes
# to default (pandas) index (https://github.com/pydata/xarray/issues/7346)
class CustomIndex(PandasIndex):
pass

ds = (
Dataset(coords={"x": [1, 2, 3]})
.drop_indexes("x")
.set_xindex("x", CustomIndex)
)
actual = ds.assign_coords(y=[4, 5, 6])
assert isinstance(actual.xindexes["x"], CustomIndex)

def test_merge_multiindex_level(self) -> None:
data = create_test_multiindex()

Expand Down

0 comments on commit 8938d39

Please sign in to comment.