From 8938d390a969a94275a4d943033a85935acbce2b Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 2 Dec 2022 17:32:40 +0100 Subject: [PATCH] Fix assign_coords resetting all dimension coords to default index (#7347) * fix merge_coords and collect_variables_and_indexes * add test * update what's new * update what's new with pull-request link --- doc/whats-new.rst | 2 ++ xarray/core/merge.py | 12 ++++++------ xarray/tests/test_dataset.py | 14 ++++++++++++++ 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 00dbe80485b..561d1bdabb5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -65,6 +65,8 @@ Bug fixes By `Michael Niklas `_. - Fix multiple reads on fsspec S3 files by resetting file pointer to 0 when reading file streams (:issue:`6813`, :pull:`7304`). By `David Hoese `_ and `Wei Ji Leong `_. +- Fix :py:meth:`Dataset.assign_coords` resetting all dimension coordinates to default (pandas) index (:issue:`7346`, :pull:`7347`). + By `BenoƮt Bovy `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 859b3aeff8f..8a8fc6f1074 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -355,12 +355,12 @@ def append_all(variables, indexes): for name, variable in mapping.items(): if isinstance(variable, DataArray): - coords = variable._coords.copy() # use private API for speed - indexes = dict(variable._indexes) + coords_ = variable._coords.copy() # use private API for speed + indexes_ = dict(variable._indexes) # explicitly overwritten variables should take precedence - coords.pop(name, None) - indexes.pop(name, None) - append_all(coords, indexes) + coords_.pop(name, None) + indexes_.pop(name, None) + append_all(coords_, indexes_) variable = as_variable(variable, name=name) if name in indexes: @@ -561,7 +561,7 @@ def merge_coords( aligned = deep_align( coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value ) - collected = collect_variables_and_indexes(aligned) + collected = collect_variables_and_indexes(aligned, indexes=indexes) prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) variables, out_indexes = merge_collected(collected, prioritized, compat=compat) return variables, out_indexes diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 8f3eb728f01..6bf3a2e3aa4 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4170,6 +4170,20 @@ def test_assign_all_multiindex_coords(self) -> None: is not actual.xindexes["level_2"] ) + def test_assign_coords_custom_index_side_effect(self) -> None: + # test that assigning new coordinates do not reset other dimension coord indexes + # to default (pandas) index (https://github.com/pydata/xarray/issues/7346) + class CustomIndex(PandasIndex): + pass + + ds = ( + Dataset(coords={"x": [1, 2, 3]}) + .drop_indexes("x") + .set_xindex("x", CustomIndex) + ) + actual = ds.assign_coords(y=[4, 5, 6]) + assert isinstance(actual.xindexes["x"], CustomIndex) + def test_merge_multiindex_level(self) -> None: data = create_test_multiindex()