From b649846b9ceef0db8631e7148f5ee9415bdd4621 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 5 Nov 2019 02:19:51 +0000 Subject: [PATCH] Propagate indexes in DataArray binary operations. (#3481) * Propagate indexes in DataArray binary operations. Works by propagating indexes in DataArray._replace. xref #2227. Tests pass! * remove commented code. * fix roll --- xarray/core/dataarray.py | 8 +++++--- xarray/core/dataset.py | 2 ++ xarray/core/groupby.py | 1 + xarray/core/indexes.py | 3 +++ xarray/tests/test_dataarray.py | 11 +++++++++++ xarray/tests/test_dataset.py | 8 ++++++++ 6 files changed, 30 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b61f83bcb1c..35ee90fb5c8 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -386,6 +386,7 @@ def _replace( variable: Variable = None, coords=None, name: Union[Hashable, None, Default] = _default, + indexes=None, ) -> "DataArray": if variable is None: variable = self.variable @@ -393,7 +394,7 @@ def _replace( coords = self._coords if name is _default: name = self.name - return type(self)(variable, coords, name=name, fastpath=True) + return type(self)(variable, coords, name=name, fastpath=True, indexes=indexes) def _replace_maybe_drop_dims( self, variable: Variable, name: Union[Hashable, None, Default] = _default @@ -440,7 +441,8 @@ def _from_temp_dataset( ) -> "DataArray": variable = dataset._variables.pop(_THIS_ARRAY) coords = dataset._variables - return self._replace(variable, coords, name) + indexes = dataset._indexes + return self._replace(variable, coords, name, indexes=indexes) def _to_dataset_split(self, dim: Hashable) -> Dataset: def subset(dim, label): @@ -2506,7 +2508,7 @@ def func(self, other): coords, indexes = self.coords._merge_raw(other_coords) name = self._result_name(other) - return self._replace(variable, coords, name) + return self._replace(variable, coords, name, indexes=indexes) return func diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2b89051e84e..978242e5f6b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4891,6 +4891,8 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): (dim,) = self.variables[k].dims if dim in shifts: indexes[k] = roll_index(v, shifts[dim]) + else: + indexes[k] = v else: indexes = dict(self.indexes) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 353566eb345..209ac14184b 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -529,6 +529,7 @@ def _maybe_unstack(self, obj): for dim in self._inserted_dims: if dim in obj.coords: del obj.coords[dim] + del obj.indexes[dim] return obj def fillna(self, value): diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index a96fbccbeee..1574f4f18df 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -35,6 +35,9 @@ def __contains__(self, key): def __getitem__(self, key): return self._indexes[key] + def __delitem__(self, key): + del self._indexes[key] + def __repr__(self): return formatting.indexes_repr(self) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 5114d13b0dc..2c823b0c20a 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3953,6 +3953,17 @@ def test_matmul(self): expected = da.dot(da) assert_identical(result, expected) + def test_binary_op_propagate_indexes(self): + # regression test for GH2227 + self.dv["x"] = np.arange(self.dv.sizes["x"]) + expected = self.dv.indexes["x"] + + actual = (self.dv * 10).indexes["x"] + assert expected is actual + + actual = (self.dv > 10).indexes["x"] + assert expected is actual + def test_binary_op_join_setting(self): dim = "x" align_type = "outer" diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index eab6040e17e..b9fa20fab26 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4951,6 +4951,14 @@ def test_filter_by_attrs(self): ) assert not bool(new_ds.data_vars) + def test_binary_op_propagate_indexes(self): + ds = Dataset( + {"d1": DataArray([1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]})} + ) + expected = ds.indexes["x"] + actual = (ds * 2).indexes["x"] + assert expected is actual + def test_binary_op_join_setting(self): # arithmetic_join applies to data array coordinates missing_2 = xr.Dataset({"x": [0, 1]})