Skip to content
This repository has been archived by the owner on Oct 7, 2024. It is now read-only.

Commit

Permalink
propagate indexes in to_dataset, from_dataset (pydata#3519)
Browse files Browse the repository at this point in the history
* Propagate indexes in _to_dataset, _from_dataset

* Make Indexes immutable again.

* Fix DataArrayGroupby._combine.

* Don't create indexes by default.

* fix to_dataset_Splt

* undo groupby change

* ccomment

* Update xarray/core/indexes.py

* Bad idea to deep copy indexes.

* remove unnecessary copy_indexes calls.

* copy_indexes → propagate_indexes

* more renaming.
  • Loading branch information
dcherian authored Nov 22, 2019
1 parent b0064b2 commit 8aabaf0
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 22 deletions.
44 changes: 31 additions & 13 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
)
from .dataset import Dataset, split_indexes
from .formatting import format_item
from .indexes import Indexes, default_indexes
from .merge import PANDAS_TYPES
from .indexes import Indexes, propagate_indexes, default_indexes
from .merge import PANDAS_TYPES, _extract_indexes_from_coords
from .options import OPTIONS
from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs
from .variable import (
Expand Down Expand Up @@ -367,6 +367,9 @@ def __init__(
data = as_compatible_data(data)
coords, dims = _infer_coords_and_dims(data.shape, coords, dims)
variable = Variable(dims, data, attrs, encoding, fastpath=True)
indexes = dict(
_extract_indexes_from_coords(coords)
) # needed for to_dataset

# These fully describe a DataArray
self._variable = variable
Expand Down Expand Up @@ -400,6 +403,7 @@ def _replace_maybe_drop_dims(
) -> "DataArray":
if variable.dims == self.dims and variable.shape == self.shape:
coords = self._coords.copy()
indexes = self._indexes
elif variable.dims == self.dims:
# Shape has changed (e.g. from reduce(..., keepdims=True)
new_sizes = dict(zip(self.dims, variable.shape))
Expand All @@ -408,12 +412,19 @@ def _replace_maybe_drop_dims(
for k, v in self._coords.items()
if v.shape == tuple(new_sizes[d] for d in v.dims)
}
changed_dims = [
k for k in variable.dims if variable.sizes[k] != self.sizes[k]
]
indexes = propagate_indexes(self._indexes, exclude=changed_dims)
else:
allowed_dims = set(variable.dims)
coords = {
k: v for k, v in self._coords.items() if set(v.dims) <= allowed_dims
}
return self._replace(variable, coords, name)
indexes = propagate_indexes(
self._indexes, exclude=(set(self.dims) - allowed_dims)
)
return self._replace(variable, coords, name, indexes=indexes)

def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
if not len(indexes):
Expand Down Expand Up @@ -444,19 +455,21 @@ def _from_temp_dataset(
return self._replace(variable, coords, name, indexes=indexes)

def _to_dataset_split(self, dim: Hashable) -> Dataset:
""" splits dataarray along dimension 'dim' """

def subset(dim, label):
array = self.loc[{dim: label}]
if dim in array.coords:
del array.coords[dim]
array.attrs = {}
return array
return as_variable(array)

variables = {label: subset(dim, label) for label in self.get_index(dim)}

coords = self.coords.to_dataset()
if dim in coords:
del coords[dim]
return Dataset(variables, coords, self.attrs)
variables.update({k: v for k, v in self._coords.items() if k != dim})
indexes = propagate_indexes(self._indexes, exclude=dim)
coord_names = set(self._coords) - set([dim])
dataset = Dataset._from_vars_and_coord_names(
variables, coord_names, indexes=indexes, attrs=self.attrs
)
return dataset

def _to_dataset_whole(
self, name: Hashable = None, shallow_copy: bool = True
Expand All @@ -480,8 +493,12 @@ def _to_dataset_whole(
if shallow_copy:
for k in variables:
variables[k] = variables[k].copy(deep=False)
indexes = self._indexes

coord_names = set(self._coords)
dataset = Dataset._from_vars_and_coord_names(variables, coord_names)
dataset = Dataset._from_vars_and_coord_names(
variables, coord_names, indexes=indexes
)
return dataset

def to_dataset(self, dim: Hashable = None, *, name: Hashable = None) -> Dataset:
Expand Down Expand Up @@ -927,7 +944,8 @@ def copy(self, deep: bool = True, data: Any = None) -> "DataArray":
"""
variable = self.variable.copy(deep=deep, data=data)
coords = {k: v.copy(deep=deep) for k, v in self._coords.items()}
return self._replace(variable, coords)
indexes = self._indexes
return self._replace(variable, coords, indexes=indexes)

def __copy__(self) -> "DataArray":
return self.copy(deep=False)
Expand Down
21 changes: 17 additions & 4 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,13 @@
remap_label_indexers,
)
from .duck_array_ops import datetime_to_numeric
from .indexes import Indexes, default_indexes, isel_variable_and_index, roll_index
from .indexes import (
Indexes,
default_indexes,
isel_variable_and_index,
propagate_indexes,
roll_index,
)
from .merge import (
dataset_merge_method,
dataset_update_method,
Expand Down Expand Up @@ -872,8 +878,12 @@ def _construct_direct(
return obj

@classmethod
def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None):
return cls._construct_direct(variables, coord_names, attrs=attrs)
def _from_vars_and_coord_names(
cls, variables, coord_names, indexes=None, attrs=None
):
return cls._construct_direct(
variables, coord_names, indexes=indexes, attrs=attrs
)

def _replace(
self,
Expand Down Expand Up @@ -4375,10 +4385,13 @@ def to_array(self, dim="variable", name=None):

coords = dict(self.coords)
coords[dim] = list(self.data_vars)
indexes = propagate_indexes(self._indexes)

dims = (dim,) + broadcast_vars[0].dims

return DataArray(data, coords, dims, attrs=self.attrs, name=name)
return DataArray(
data, coords, dims, attrs=self.attrs, name=name, indexes=indexes
)

def _to_dataframe(self, ordered_dims):
columns = [k for k in self.variables if k not in self.dims]
Expand Down
6 changes: 4 additions & 2 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .common import ImplementsArrayReduce, ImplementsDatasetReduce
from .concat import concat
from .formatting import format_array_flat
from .indexes import propagate_indexes
from .options import _get_keep_attrs
from .pycompat import integer_types
from .utils import (
Expand Down Expand Up @@ -529,7 +530,7 @@ def _maybe_unstack(self, obj):
for dim in self._inserted_dims:
if dim in obj.coords:
del obj.coords[dim]
del obj.indexes[dim]
obj._indexes = propagate_indexes(obj._indexes, exclude=self._inserted_dims)
return obj

def fillna(self, value):
Expand Down Expand Up @@ -786,7 +787,8 @@ def _combine(self, applied, restore_coord_dims=False, shortcut=False):
combined = self._restore_dim_order(combined)
if coord is not None:
if shortcut:
combined._coords[coord.name] = as_variable(coord)
coord_var = as_variable(coord)
combined._coords[coord.name] = coord_var
else:
combined.coords[coord.name] = coord
combined = self._maybe_restore_empty_groups(combined)
Expand Down
23 changes: 20 additions & 3 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd

from . import formatting
from .utils import is_scalar
from .variable import Variable


Expand Down Expand Up @@ -35,9 +36,6 @@ def __contains__(self, key):
def __getitem__(self, key):
return self._indexes[key]

def __delitem__(self, key):
del self._indexes[key]

def __repr__(self):
return formatting.indexes_repr(self)

Expand Down Expand Up @@ -100,3 +98,22 @@ def roll_index(index: pd.Index, count: int, axis: int = 0) -> pd.Index:
return index[-count:].append(index[:-count])
else:
return index[:]


def propagate_indexes(
indexes: Optional[Dict[Hashable, pd.Index]], exclude: Optional[Any] = None
) -> Optional[Dict[Hashable, pd.Index]]:
""" Creates new indexes dict from existing dict optionally excluding some dimensions.
"""
if exclude is None:
exclude = ()

if is_scalar(exclude):
exclude = (exclude,)

if indexes is not None:
new_indexes = {k: v for k, v in indexes.items() if k not in exclude}
else:
new_indexes = None # type: ignore

return new_indexes
2 changes: 2 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from xarray.convert import from_cdms2
from xarray.core import dtypes
from xarray.core.common import full_like
from xarray.core.indexes import propagate_indexes
from xarray.tests import (
LooseVersion,
ReturnItem,
Expand Down Expand Up @@ -1239,6 +1240,7 @@ def test_coords(self):
assert expected == actual

del da.coords["x"]
da._indexes = propagate_indexes(da._indexes, exclude="x")
expected = DataArray(da.values, {"y": [0, 1, 2]}, dims=["x", "y"], name="foo")
assert_identical(da, expected)

Expand Down

0 comments on commit 8aabaf0

Please sign in to comment.