From e677b7a0aa344faee3eb407e63422038c2029399 Mon Sep 17 00:00:00 2001
From: Stephan Hoyer
Date: Wed, 6 Feb 2019 08:07:38 -0800
Subject: [PATCH] Refactor (part of) dataset.py to use explicit indexes (#2696)

* Refactor (part of) dataset.py to use explicit indexes

* Use copy.copy()

* Ensure coordinate order is deterministic
---
 xarray/core/alignment.py       |  30 ++-
 xarray/core/dataset.py         | 419 ++++++++++++++++++++++-----------
 xarray/core/duck_array_ops.py  |   2 +-
 xarray/core/indexes.py         |  47 +++-
 xarray/core/merge.py           |  38 +--
 xarray/core/variable.py        |   4 +-
 xarray/tests/test_dataarray.py |   4 +-
 xarray/tests/test_dataset.py   |   2 +-
 8 files changed, 372 insertions(+), 174 deletions(-)

diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py
index c44a0c4201d..7aaeff00b5e 100644
--- a/xarray/core/alignment.py
+++ b/xarray/core/alignment.py
@@ -3,13 +3,15 @@
 import warnings
 from collections import OrderedDict, defaultdict
 from contextlib import suppress
+from typing import Any, Mapping, Optional, Tuple

 import numpy as np
+import pandas as pd

 from . import utils
 from .indexing import get_indexer_nd
 from .utils import is_dict_like, is_full_slice
-from .variable import IndexVariable
+from .variable import IndexVariable, Variable


 def _get_joiner(join):
@@ -260,8 +262,15 @@ def reindex_like_indexers(target, other):
     return indexers


-def reindex_variables(variables, sizes, indexes, indexers, method=None,
-                      tolerance=None, copy=True):
+def reindex_variables(
+    variables: Mapping[Any, Variable],
+    sizes: Mapping[Any, int],
+    indexes: Mapping[Any, pd.Index],
+    indexers: Mapping,
+    method: Optional[str] = None,
+    tolerance: Any = None,
+    copy: bool = True,
+) -> 'Tuple[OrderedDict[Any, Variable], OrderedDict[Any, pd.Index]]':
     """Conform a dictionary of aligned variables onto a new set of variables,
     filling in missing values with NaN.

@@ -274,7 +283,7 @@ def reindex_variables(variables, sizes, indexes, indexers, method=None,
     sizes : dict-like
         Dictionary from dimension names to integer sizes.
     indexes : dict-like
-        Dictionary of xarray.IndexVariable objects associated with variables.
+        Dictionary of indexes associated with variables.
     indexers : dict
         Dictionary with keys given by dimension names and values given by
         arrays of coordinate tick labels. Any mis-matched coordinate values
@@ -300,13 +309,15 @@ def reindex_variables(variables, sizes, indexes, indexers, method=None,
     Returns
     -------
     reindexed : OrderedDict
-        Another dict, with the items in variables but replaced indexes.
+        Dict of reindexed variables.
+    new_indexes : OrderedDict
+        Dict of indexes associated with the reindexed variables.
""" from .dataarray import DataArray # build up indexers for assignment along each dimension int_indexers = {} - targets = {} + targets = OrderedDict() masked_dims = set() unchanged_dims = set() @@ -359,7 +370,7 @@ def reindex_variables(variables, sizes, indexes, indexers, method=None, if dim in variables: var = variables[dim] - args = (var.attrs, var.encoding) + args = (var.attrs, var.encoding) # type: tuple else: args = () reindexed[dim] = IndexVariable((dim,), indexers[dim], *args) @@ -384,7 +395,10 @@ def reindex_variables(variables, sizes, indexes, indexers, method=None, reindexed[name] = new_var - return reindexed + new_indexes = OrderedDict(indexes) + new_indexes.update(targets) + + return reindexed, new_indexes def broadcast(*args, **kwargs): diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 8863dedb7db..d1323c171eb 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1,3 +1,4 @@ +import copy import functools import sys import warnings @@ -5,7 +6,10 @@ from collections.abc import Mapping from distutils.version import LooseVersion from numbers import Number -from typing import Any, Dict, List, Set, Tuple, Union +from typing import ( + Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, TYPE_CHECKING, + Union, +) import numpy as np import pandas as pd @@ -22,8 +26,9 @@ _contains_datetime_like_objects) from .coordinates import ( DatasetCoordinates, LevelCoordinatesSource, assert_coordinate_consistent, - remap_label_indexers) -from .indexes import Indexes, default_indexes + remap_label_indexers, +) +from .indexes import Indexes, default_indexes, isel_variable_and_index from .merge import ( dataset_merge_method, dataset_update_method, merge_data_and_coords, merge_variables) @@ -34,6 +39,9 @@ decode_numpy_dict_values, either_dict_or_kwargs, hashable, maybe_wrap_array) from .variable import IndexVariable, Variable, as_variable, broadcast_variables +if TYPE_CHECKING: + from .dataarray import DataArray + # list of attributes of pd.DatetimeIndex that are ndarrays of time info _DATETIMEINDEX_COMPONENTS = ['year', 'month', 'day', 'hour', 'minute', @@ -305,6 +313,9 @@ def __getitem__(self, key): return self.dataset.sel(**key) +T = TypeVar('T', bound='Dataset') + + class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): """A multi-dimensional, in memory, array database. @@ -350,7 +361,6 @@ def __init__(self, data_vars=None, coords=None, attrs=None, Global attributes to save on this dataset. compat : deprecated """ - if compat is not None: warnings.warn( 'The `compat` argument to Dataset is deprecated and will be ' @@ -359,10 +369,11 @@ def __init__(self, data_vars=None, coords=None, attrs=None, FutureWarning, stacklevel=2) else: compat = 'broadcast_equals' - self._variables = OrderedDict() + + self._variables = OrderedDict() # type: OrderedDict[Any, Variable] self._coord_names = set() - self._dims = {} - self._attrs = None + self._dims = {} # type: Dict[Any, int] + self._attrs = None # type: Optional[OrderedDict] self._file_obj = None if data_vars is None: data_vars = {} @@ -410,7 +421,7 @@ def load_store(cls, store, decoder=None): return obj @property - def variables(self): + def variables(self) -> 'Mapping[Any, Variable]': """Low level interface to Dataset contents as dict of Variable objects. 
This ordered dictionary is frozen to prevent mutation that could @@ -420,11 +431,8 @@ def variables(self): """ return Frozen(self._variables) - def _attrs_copy(self): - return None if self._attrs is None else OrderedDict(self._attrs) - @property - def attrs(self): + def attrs(self) -> Mapping: """Dictionary of global attributes on this dataset """ if self._attrs is None: @@ -436,7 +444,7 @@ def attrs(self, value): self._attrs = OrderedDict(value) @property - def encoding(self): + def encoding(self) -> Dict: """Dictionary of global encoding attributes on this dataset """ if self._encoding is None: @@ -448,7 +456,7 @@ def encoding(self, value): self._encoding = dict(value) @property - def dims(self): + def dims(self) -> 'Mapping[Any, int]': """Mapping from dimension names to lengths. Cannot be modified directly, but is updated when adding new variables. @@ -460,7 +468,7 @@ def dims(self): return Frozen(SortedKeysDict(self._dims)) @property - def sizes(self): + def sizes(self) -> 'Mapping[Any, int]': """Mapping from dimension names to lengths. Cannot be modified directly, but is updated when adding new variables. @@ -474,7 +482,7 @@ def sizes(self): """ return self.dims - def load(self, **kwargs): + def load(self: T, **kwargs) -> T: """Manually trigger loading of this dataset's data from disk or a remote source into memory and return this dataset. @@ -549,18 +557,32 @@ def __dask_postcompute__(self): info = [(True, k, v.__dask_postcompute__()) if dask.is_dask_collection(v) else (False, k, v) for k, v in self._variables.items()] - return self._dask_postcompute, (info, self._coord_names, self._dims, - self._attrs, self._file_obj, - self._encoding) + args = ( + info, + self._coord_names, + self._dims, + self._attrs, + self._indexes, + self._encoding, + self._file_obj, + ) + return self._dask_postcompute, args def __dask_postpersist__(self): import dask info = [(True, k, v.__dask_postpersist__()) if dask.is_dask_collection(v) else (False, k, v) for k, v in self._variables.items()] - return self._dask_postpersist, (info, self._coord_names, self._dims, - self._attrs, self._file_obj, - self._encoding) + args = ( + info, + self._coord_names, + self._dims, + self._attrs, + self._indexes, + self._encoding, + self._file_obj, + ) + return self._dask_postpersist, args @staticmethod def _dask_postcompute(results, info, *args): @@ -591,7 +613,7 @@ def _dask_postpersist(dsk, info, *args): return Dataset._construct_direct(variables, *args) - def compute(self, **kwargs): + def compute(self: T, **kwargs) -> T: """Manually trigger loading of this dataset's data from disk or a remote source into memory and return a new dataset. The original is left unaltered. 
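Note on the dask hooks above: `__dask_postcompute__` and `__dask_postpersist__`
now pass the cached `_indexes` (plus `_encoding` and `_file_obj`, in the new
argument order) through `_construct_direct`, so a compute/persist round-trip
carries the indexes along instead of rebuilding them from coordinate data. A
minimal sketch of the observable invariant, using only public API (assumes
dask is installed; the variable and coordinate names are illustrative):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"t": (("x",), np.arange(4.0))},
        coords={"x": [10, 20, 30, 40]},
    ).chunk({"x": 2})

    computed = ds.compute()  # rebuilt via _construct_direct
    assert computed.indexes["x"].equals(ds.indexes["x"])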
@@ -629,7 +651,7 @@ def _persist_inplace(self, **kwargs):

         return self

-    def persist(self, **kwargs):
+    def persist(self: T, **kwargs) -> T:
         """ Trigger computation, keeping data as dask arrays

         This operation can be used to trigger computation on underlying dask
@@ -651,8 +673,8 @@ def persist(self, **kwargs):
         return new._persist_inplace(**kwargs)

     @classmethod
-    def _construct_direct(cls, variables, coord_names, dims=None, attrs=None,
-                          indexes=None, file_obj=None, encoding=None):
+    def _construct_direct(cls, variables, coord_names, dims, attrs=None,
+                          indexes=None, encoding=None, file_obj=None):
         """Shortcut around __init__ for internal use when we want to skip
         costly validation
         """
@@ -667,62 +689,103 @@ def _construct_direct(cls, variables, coord_names, dims=None, attrs=None,
         obj._initialized = True
         return obj

-    __default_attrs = object()
+    __default = object()

     @classmethod
     def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None):
         dims = dict(calculate_dimensions(variables))
         return cls._construct_direct(variables, coord_names, dims, attrs)

-    def _replace_vars_and_dims(self, variables, coord_names=None, dims=None,
-                               attrs=__default_attrs, indexes=None,
-                               inplace=False):
+    def _replace(
+        self: T,
+        variables: 'OrderedDict[Any, Variable]' = None,
+        coord_names: set = None,
+        dims: 'OrderedDict[Any, int]' = None,
+        attrs: 'Optional[OrderedDict]' = __default,
+        indexes: 'Optional[OrderedDict[Any, pd.Index]]' = __default,
+        encoding: Optional[dict] = __default,
+        inplace: bool = False,
+    ) -> T:
         """Fastpath constructor for internal use.

-        Preserves coord names and attributes. If not provided explicitly,
-        dimensions are recalculated from the supplied variables.
-
-        The arguments are *not* copied when placed on the new dataset. It is up
-        to the caller to ensure that they have the right type and are not used
-        elsewhere.
-
-        Parameters
-        ----------
-        variables : OrderedDict
-        coord_names : set or None, optional
-        attrs : OrderedDict or None, optional
+        Returns an object with optionally replaced attributes.

-        Returns
-        -------
-        new : Dataset
+        Explicitly passed arguments are *not* copied when placed on the new
+        dataset. It is up to the caller to ensure that they have the right type
+        and are not used elsewhere.
""" - if dims is None: - dims = calculate_dimensions(variables) if inplace: - self._dims = dims - self._variables = variables + if variables is not None: + self._variables = variables if coord_names is not None: self._coord_names = coord_names - if attrs is not self.__default_attrs: + if dims is not None: + self._dims = dims + if attrs is not self.__default: self._attrs = attrs - self._indexes = indexes + if indexes is not self.__default: + self._indexes = indexes + if encoding is not self.__default: + self._encoding = encoding obj = self else: + if variables is None: + variables = self._variables.copy() if coord_names is None: coord_names = self._coord_names.copy() - if attrs is self.__default_attrs: - attrs = self._attrs_copy() + if dims is None: + dims = self._dims.copy() + if attrs is self.__default: + attrs = copy.copy(self._attrs) + if indexes is self.__default: + indexes = copy.copy(self._indexes) + if encoding is self.__default: + encoding = copy.copy(self._encoding) obj = self._construct_direct( - variables, coord_names, dims, attrs, indexes) + variables, coord_names, dims, attrs, indexes, encoding) return obj - def _replace_indexes(self, indexes): - if not len(indexes): + def _replace_with_new_dims( + self: T, + variables: 'OrderedDict[Any, Variable]' = None, + coord_names: set = None, + attrs: 'Optional[OrderedDict]' = __default, + indexes: 'Optional[OrderedDict[Any, pd.Index]]' = __default, + inplace: bool = False, + ) -> T: + """Replace variables with recalculated dimensions.""" + dims = dict(calculate_dimensions(variables)) + return self._replace( + variables, coord_names, dims, attrs, indexes, inplace=inplace) + + def _replace_vars_and_dims( + self: T, + variables: 'OrderedDict[Any, Variable]' = None, + coord_names: set = None, + dims: 'OrderedDict[Any, int]' = None, + attrs: 'Optional[OrderedDict]' = __default, + inplace: bool = False, + ) -> T: + """Deprecated version of _replace_with_new_dims(). + + Unlike _replace_with_new_dims(), this method always recalculates + indexes from variables. + """ + if dims is None: + dims = calculate_dimensions(variables) + return self._replace( + variables, coord_names, dims, attrs, indexes=None, inplace=inplace) + + def _overwrite_indexes(self, indexes): + if not indexes: return self + variables = self._variables.copy() + new_indexes = OrderedDict(self.indexes) for name, idx in indexes.items(): variables[name] = IndexVariable(name, idx) - obj = self._replace_vars_and_dims(variables) + new_indexes[name] = idx + obj = self._replace(variables, indexes=new_indexes) # switch from dimension to level names, if necessary dim_names = {} @@ -733,7 +796,7 @@ def _replace_indexes(self, indexes): obj = obj.rename(dim_names) return obj - def copy(self, deep=False, data=None): + def copy(self: T, deep: bool = False, data: Mapping = None) -> T: """Returns a copy of this dataset. If `deep=True`, a deep copy is made of each of the component variables. 
@@ -849,22 +912,7 @@ def copy(self, deep=False, data=None):
         variables = OrderedDict((k, v.copy(deep=deep, data=data.get(k)))
                                 for k, v in self._variables.items())

-        # skip __init__ to avoid costly validation
-        return self._construct_direct(variables, self._coord_names.copy(),
-                                      self._dims.copy(), self._attrs_copy(),
-                                      encoding=self.encoding)
-
-    def _subset_with_all_valid_coords(self, variables, coord_names, attrs):
-        needed_dims = set()
-        for v in variables.values():
-            needed_dims.update(v.dims)
-        for k in self._coord_names:
-            if set(self.variables[k].dims) <= needed_dims:
-                variables[k] = self._variables[k]
-                coord_names.add(k)
-        dims = dict((k, self._dims[k]) for k in needed_dims)
-
-        return self._construct_direct(variables, coord_names, dims, attrs)
+        return self._replace(variables)

     @property
     def _level_coords(self):
@@ -872,16 +920,14 @@ def _level_coords(self):
         coordinate name.
         """
         level_coords = OrderedDict()
-        for cname in self._coord_names:
-            var = self.variables[cname]
-            if var.ndim == 1 and isinstance(var, IndexVariable):
-                level_names = var.level_names
-                if level_names is not None:
-                    dim, = var.dims
-                    level_coords.update({lname: dim for lname in level_names})
+        for name, index in self.indexes.items():
+            if isinstance(index, pd.MultiIndex):
+                level_names = index.names
+                (dim,) = self.variables[name].dims
+                level_coords.update({lname: dim for lname in level_names})
         return level_coords

-    def _copy_listed(self, names):
+    def _copy_listed(self: T, names) -> T:
         """Create a new Dataset with the listed variables from this dataset and
         all the relevant coordinates. Skips all validation.
         """
@@ -898,10 +944,26 @@ def _copy_listed(self, names):
             if ref_name in self._coord_names or ref_name in self.dims:
                 coord_names.add(var_name)

-        return self._subset_with_all_valid_coords(variables, coord_names,
-                                                  attrs=self.attrs.copy())
+        needed_dims = set()  # type: set
+        for v in variables.values():
+            needed_dims.update(v.dims)
+
+        dims = dict((k, self.dims[k]) for k in needed_dims)

-    def _construct_dataarray(self, name):
+        for k in self._coord_names:
+            if set(self.variables[k].dims) <= needed_dims:
+                variables[k] = self._variables[k]
+                coord_names.add(k)
+
+        if self._indexes is None:
+            indexes = None
+        else:
+            indexes = OrderedDict((k, v) for k, v in self._indexes.items()
+                                  if k in coord_names)
+
+        return self._replace(variables, coord_names, dims, indexes=indexes)
+
+    def _construct_dataarray(self, name) -> 'DataArray':
         """Construct a DataArray by indexing this dataset
         """
         from .dataarray import DataArray
@@ -912,13 +974,21 @@ def _construct_dataarray(self, name):
             _, name, variable = _get_virtual_variable(
                 self._variables, name, self._level_coords, self.dims)

-        coords = OrderedDict()
         needed_dims = set(variable.dims)
+
+        coords = OrderedDict()
         for k in self.coords:
             if set(self.variables[k].dims) <= needed_dims:
                 coords[k] = self.variables[k]

-        return DataArray(variable, coords, name=name, fastpath=True)
+        if self._indexes is None:
+            indexes = None
+        else:
+            indexes = OrderedDict((k, v) for k, v in self._indexes.items()
+                                  if k in coords)
+
+        return DataArray(variable, coords, name=name, indexes=indexes,
+                         fastpath=True)

     def __copy__(self):
         return self.copy(deep=False)
@@ -1078,7 +1148,7 @@ def identical(self, other):
         return False

     @property
-    def indexes(self):
+    def indexes(self) -> 'Mapping[Any, pd.Index]':
         """Mapping of pandas.Index objects used for label based indexing
         """
         if self._indexes is None:
@@ -1410,9 +1480,11 @@ def maybe_chunk(name, var, chunks):
         variables = OrderedDict([(k, maybe_chunk(k, v,
chunks))
                                 for k, v in self.variables.items()])
-        return self._replace_vars_and_dims(variables)
+        return self._replace(variables)

-    def _validate_indexers(self, indexers):
+    def _validate_indexers(
+        self, indexers: Mapping,
+    ) -> List[Tuple[Any, Union[slice, Variable]]]:
         """ Here we make sure
         + indexers have valid keys
         + indexers are of a valid data type
@@ -1457,7 +1529,7 @@
             indexers_list.append((k, v))
         return indexers_list

-    def _get_indexers_coordinates(self, indexers):
+    def _get_indexers_coords_and_indexes(self, indexers):
         """ Extract coordinates and indexes from indexers.
         Returns an OrderedDict mapping from coordinate name to the
         coordinate variable, along with an OrderedDict of extracted indexes.
@@ -1468,6 +1540,7 @@
         from .dataarray import DataArray

         coord_list = []
+        indexes = OrderedDict()
         for k, v in indexers.items():
             if isinstance(v, DataArray):
                 v_coords = v.coords
@@ -1482,17 +1555,22 @@
                 v_coords = v[v.values.nonzero()[0]].coords
             coord_list.append({d: v_coords[d].variable for d in v.coords})
+            indexes.update(v.indexes)

-        # we don't need to call align() explicitly, because merge_variables
-        # already checks for exact alignment between dimension coordinates
+        # we don't need to call align() explicitly or check indexes for
+        # alignment, because merge_variables already checks for exact alignment
+        # between dimension coordinates
         coords = merge_variables(coord_list)
         assert_coordinate_consistent(self, coords)

-        attached_coords = OrderedDict()
-        for k, v in coords.items():  # silently drop the conflicted variables.
-            if k not in self._variables:
-                attached_coords[k] = v
-        return attached_coords
+        # silently drop the conflicted variables.
+        attached_coords = OrderedDict(
+            (k, v) for k, v in coords.items() if k not in self._variables
+        )
+        attached_indexes = OrderedDict(
+            (k, v) for k, v in indexes.items() if k not in self._variables
+        )
+        return attached_coords, attached_indexes

     def isel(self, indexers=None, drop=False, **indexers_kwargs):
         """Returns a new dataset with each array indexed along the specified
@@ -1540,23 +1618,36 @@
         indexers_list = self._validate_indexers(indexers)

         variables = OrderedDict()
-        for name, var in self._variables.items():
+        indexes = OrderedDict()
+        for name, var in self.variables.items():
             var_indexers = {k: v for k, v in indexers_list if k in var.dims}
-            new_var = var.isel(indexers=var_indexers)
-            if not (drop and name in var_indexers):
-                variables[name] = new_var
+            if drop and name in var_indexers:
+                continue  # drop this variable
+
+            if name in self.indexes:
+                new_var, new_index = isel_variable_and_index(
+                    var, self.indexes[name], var_indexers)
+                if new_index is not None:
+                    indexes[name] = new_index
+            else:
+                new_var = var.isel(indexers=var_indexers)
+
+            variables[name] = new_var

         coord_names = set(variables).intersection(self._coord_names)
-        selected = self._replace_vars_and_dims(variables,
-                                               coord_names=coord_names)
+        selected = self._replace_with_new_dims(
+            variables, coord_names, indexes=indexes)

         # Extract coordinates from indexers
-        coord_vars = selected._get_indexers_coordinates(indexers)
+        coord_vars, new_indexes = (
+            selected._get_indexers_coords_and_indexes(indexers))
         variables.update(coord_vars)
+        indexes.update(new_indexes)
         coord_names = (set(variables)
                        .intersection(self._coord_names)
                        .union(coord_vars))
-        return self._replace_vars_and_dims(variables, coord_names=coord_names)
+        return self._replace_with_new_dims(
variables, coord_names, indexes=indexes) def sel(self, indexers=None, method=None, tolerance=None, drop=False, **indexers_kwargs): @@ -1626,7 +1717,7 @@ def sel(self, indexers=None, method=None, tolerance=None, drop=False, pos_indexers, new_indexes = remap_label_indexers( self, indexers=indexers, method=method, tolerance=tolerance) result = self.isel(indexers=pos_indexers, drop=drop) - return result._replace_indexes(new_indexes) + return result._overwrite_indexes(new_indexes) def isel_points(self, dim='points', **indexers): # type: (...) -> Dataset @@ -1926,12 +2017,13 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True, if bad_dims: raise ValueError('invalid reindex dimensions: %s' % bad_dims) - variables = alignment.reindex_variables( + variables, indexes = alignment.reindex_variables( self.variables, self.sizes, self.indexes, indexers, method, tolerance, copy=copy) coord_names = set(self._coord_names) coord_names.update(indexers) - return self._replace_vars_and_dims(variables, coord_names) + return self._replace_with_new_dims( + variables, coord_names, indexes=indexes) def interp(self, coords=None, method='linear', assume_sorted=False, kwargs={}, **coords_kwargs): @@ -2005,9 +2097,11 @@ def _validate_interp_indexer(x, new_x): for name, var in obj._variables.items(): if name not in indexers: if var.dtype.kind in 'uifc': - var_indexers = {k: _validate_interp_indexer( - maybe_variable(obj, k), v) for k, v - in indexers.items() if k in var.dims} + var_indexers = { + k: _validate_interp_indexer(maybe_variable(obj, k), v) + for k, v in indexers.items() + if k in var.dims + } variables[name] = missing.interp( var, var_indexers, method, **kwargs) elif all(d not in indexers for d in var.dims): @@ -2015,17 +2109,23 @@ def _validate_interp_indexer(x, new_x): variables[name] = var coord_names = set(variables).intersection(obj._coord_names) - selected = obj._replace_vars_and_dims(variables, - coord_names=coord_names) + indexes = OrderedDict( + (k, v) for k, v in obj.indexes.items() if k not in indexers) + selected = self._replace_with_new_dims( + variables, coord_names, indexes=indexes) + # attach indexer as coordinate variables.update(indexers) # Extract coordinates from indexers - coord_vars = selected._get_indexers_coordinates(coords) + coord_vars, new_indexes = ( + selected._get_indexers_coords_and_indexes(coords)) variables.update(coord_vars) + indexes.update(new_indexes) coord_names = (set(variables) .intersection(obj._coord_names) .union(coord_vars)) - return obj._replace_vars_and_dims(variables, coord_names=coord_names) + return self._replace_with_new_dims( + variables, coord_names, indexes=indexes) def interp_like(self, other, method='linear', assume_sorted=False, kwargs={}): @@ -2084,6 +2184,46 @@ def interp_like(self, other, method='linear', assume_sorted=False, ds = self.reindex(object_coords) return ds.interp(numeric_coords, method, assume_sorted, kwargs) + # Helper methods for rename() + def _rename_vars(self, name_dict, dims_dict): + variables = OrderedDict() + coord_names = set() + for k, v in self.variables.items(): + name = name_dict.get(k, k) + dims = tuple(dims_dict.get(dim, dim) for dim in v.dims) + var = v.copy(deep=False) + var.dims = dims + if name in variables: + raise ValueError('the new name %r conflicts' % (name,)) + variables[name] = var + if k in self._coord_names: + coord_names.add(name) + return variables, coord_names + + def _rename_dims(self, dims_dict): + return {dims_dict.get(k, k): v for k, v in self.dims.items()} + + def 
_rename_indexes(self, name_dict): + if self._indexes is None: + return None + indexes = OrderedDict() + for k, v in self.indexes.items(): + new_name = name_dict.get(k, k) + if isinstance(v, pd.MultiIndex): + new_names = [name_dict.get(k, k) for k in v.names] + index = pd.MultiIndex(v.levels, v.labels, v.sortorder, + names=new_names, verify_integrity=False) + else: + index = pd.Index(v, name=new_name) + indexes[new_name] = index + return indexes + + def _rename_all(self, name_dict, dim_dict): + variables, coord_names = self._rename_vars(name_dict, dim_dict) + dims = self._rename_dims(dim_dict) + indexes = self._rename_indexes(name_dict) + return variables, coord_names, dims, indexes + def rename(self, name_dict=None, inplace=None, **names): """Returns a new object with renamed variables and dimensions. @@ -2109,6 +2249,7 @@ def rename(self, name_dict=None, inplace=None, **names): Dataset.swap_dims DataArray.rename """ + # TODO: add separate rename_vars and rename_dims methods. inplace = _check_inplace(inplace) name_dict = either_dict_or_kwargs(name_dict, names, 'rename') for k, v in name_dict.items(): @@ -2116,24 +2257,10 @@ def rename(self, name_dict=None, inplace=None, **names): raise ValueError("cannot rename %r because it is not a " "variable or dimension in this dataset" % k) - variables = OrderedDict() - coord_names = set() - for k, v in self._variables.items(): - name = name_dict.get(k, k) - dims = tuple(name_dict.get(dim, dim) for dim in v.dims) - var = v.copy(deep=False) - var.dims = dims - if name in variables: - raise ValueError('the new name %r conflicts' % (name,)) - variables[name] = var - if k in self._coord_names: - coord_names.add(name) - - dims = OrderedDict((name_dict.get(k, k), v) - for k, v in self.dims.items()) - - return self._replace_vars_and_dims(variables, coord_names, dims=dims, - inplace=inplace) + variables, coord_names, dims, indexes = self._rename_all( + name_dict=name_dict, dim_dict=name_dict) + return self._replace(variables, coord_names, dims=dims, + indexes=indexes, inplace=inplace) def swap_dims(self, dims_dict, inplace=None): """Returns a new object with swapped dimensions. @@ -2159,6 +2286,8 @@ def swap_dims(self, dims_dict, inplace=None): Dataset.rename DataArray.swap_dims """ + # TODO: deprecate this method in favor of a (less confusing) + # rename_dims() method that only renames dimensions. 
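        # For contrast with the rename_dims() idea above, a sketch of what
        # swap_dims() does today (public API only; not part of this patch):
        #
        #     import xarray as xr
        #     ds = xr.Dataset({"y": ("x", [10.0, 20.0, 30.0])},
        #                     coords={"x": ["a", "b", "c"]})
        #     swapped = ds.swap_dims({"x": "y"})
        #     assert list(swapped.indexes) == ["y"]  # index rebuilt on "y"
        #     assert "x" in swapped.coords           # "x" kept as a coordinate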
        inplace = _check_inplace(inplace)
        for k, v in dims_dict.items():
            if k not in self.dims:
@@ -2171,11 +2300,10 @@
         result_dims = set(dims_dict.get(dim, dim) for dim in self.dims)

-        variables = OrderedDict()
-
         coord_names = self._coord_names.copy()
         coord_names.update(dims_dict.values())

+        variables = OrderedDict()
         for k, v in self.variables.items():
             dims = tuple(dims_dict.get(dim, dim) for dim in v.dims)
             if k in result_dims:
                 var = v.to_index_variable()
             else:
                 var = v.to_base_variable()
             var.dims = dims
             variables[k] = var

-        return self._replace_vars_and_dims(variables, coord_names,
-                                           inplace=inplace)
+        indexes = OrderedDict()
+        for k, v in self.indexes.items():
+            if k in dims_dict:
+                new_name = dims_dict[k]
+                new_index = variables[new_name].to_index()
+                indexes[new_name] = new_index
+            else:
+                indexes[k] = v
+
+        return self._replace_with_new_dims(variables, coord_names,
+                                           indexes=indexes, inplace=inplace)

     def expand_dims(self, dim, axis=None):
         """Return a new object with an additional axis (or axes) inserted at
@@ -2270,7 +2407,11 @@ def expand_dims(self, dim, axis=None):
             # it will be promoted to a 1D coordinate with a single value.
             variables[k] = v.set_dims(k)

-        return self._replace_vars_and_dims(variables, self._coord_names)
+        new_dims = self._dims.copy()
+        for d in dim:
+            new_dims[d] = 1
+
+        return self._replace(variables, dims=new_dims)

     def set_index(self, indexes=None, append=False, inplace=None,
                   **indexes_kwargs):
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index 330bdc19cfc..36c4090297d 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -302,7 +302,7 @@ def mean(array, axis=None, skipna=None, **kwargs):
     return _mean(array, axis=axis, skipna=skipna, **kwargs)

-mean.numeric_only = True
+mean.numeric_only = True  # type: ignore

 def _nd_cum_func(cum_func, array, axis, **kwargs):
diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index c360a209c46..6d8b553036a 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -1,10 +1,14 @@
-from collections.abc import Mapping
+import collections.abc
 from collections import OrderedDict
+from typing import Any, Iterable, Mapping, Optional, Tuple, Union
+
+import pandas as pd

 from . import formatting
+from .variable import Variable

-class Indexes(Mapping):
+class Indexes(collections.abc.Mapping):
     """Immutable proxy for Dataset or DataArray indexes."""
     def __init__(self, indexes):
         """Not for public consumption.
@@ -32,7 +36,10 @@ def __repr__(self):
         return formatting.indexes_repr(self)

-def default_indexes(coords, dims):
+def default_indexes(
+    coords: Mapping[Any, Variable],
+    dims: Iterable,
+) -> 'OrderedDict[Any, pd.Index]':
     """Default indexes for a Dataset/DataArray.

     Parameters
@@ -44,8 +51,38 @@

     Returns
     -------
-    Mapping[Any, pandas.Index] mapping indexing keys (levels/dimension names)
-    to indexes used for indexing along that dimension.
+    Mapping from indexing keys (levels/dimension names) to indexes used for
+    indexing along that dimension.
""" return OrderedDict((key, coords[key].to_index()) for key in dims if key in coords) + + +def isel_variable_and_index( + variable: Variable, + index: pd.Index, + indexers: Mapping[Any, Union[slice, Variable]], +) -> Tuple[Variable, Optional[pd.Index]]: + """Index a Variable and pandas.Index together.""" + if not indexers: + # nothing to index + return variable.copy(deep=False), index + + if len(variable.dims) > 1: + raise NotImplementedError( + 'indexing multi-dimensional variable with indexes is not ' + 'supported yet') + + new_variable = variable.isel(indexers) + + if new_variable.ndim != 1: + # can't preserve a index if result is not 0D + return new_variable, None + + # we need to compute the new index + (dim,) = variable.dims + indexer = indexers[dim] + if isinstance(indexer, Variable): + indexer = indexer.data + new_index = index[indexer] + return new_variable, new_index diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 7bbd14470f2..daf400765d5 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -175,8 +175,9 @@ def merge_variables( return merged -def expand_variable_dicts(list_of_variable_dicts): - # type: (List[Union[Dataset, Dict]]) -> List[Dict[Any, Variable]] +def expand_variable_dicts( + list_of_variable_dicts: 'List[Union[Dataset, OrderedDict]]', +) -> 'List[OrderedDict[Any, Variable]]': """Given a list of dicts with xarray object values, expand the values. Parameters @@ -201,22 +202,23 @@ def expand_variable_dicts(list_of_variable_dicts): for variables in list_of_variable_dicts: if isinstance(variables, Dataset): - sanitized_vars = variables.variables - else: - # append coords to var_dicts before appending sanitized_vars, - # because we want coords to appear first - sanitized_vars = OrderedDict() + var_dicts.append(variables.variables) + continue - for name, var in variables.items(): - if isinstance(var, DataArray): - # use private API for speed - coords = var._coords.copy() - # explicitly overwritten variables should take precedence - coords.pop(name, None) - var_dicts.append(coords) + # append coords to var_dicts before appending sanitized_vars, + # because we want coords to appear first + sanitized_vars = OrderedDict() # type: OrderedDict[Any, Variable] + + for name, var in variables.items(): + if isinstance(var, DataArray): + # use private API for speed + coords = var._coords.copy() + # explicitly overwritten variables should take precedence + coords.pop(name, None) + var_dicts.append(coords) - var = as_variable(var, name=name) - sanitized_vars[name] = var + var = as_variable(var, name=name) + sanitized_vars[name] = var var_dicts.append(sanitized_vars) @@ -526,7 +528,9 @@ def merge(objects, compat='no_conflicts', join='outer'): for obj in objects] variables, coord_names, dims = merge_core(dict_like_objects, compat, join) - merged = Dataset._construct_direct(variables, coord_names, dims) + # TODO: don't always recompute indexes + merged = Dataset._construct_direct( + variables, coord_names, dims, indexes=None) return merged diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 23ee9f24871..a35f8cf02f0 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2,7 +2,7 @@ import itertools from collections import OrderedDict, defaultdict from datetime import timedelta -from typing import Tuple, Type +from typing import Tuple, Type, Union import numpy as np import pandas as pd @@ -38,7 +38,7 @@ class MissingDimensionsError(ValueError): # TODO: move this to an xarray.exceptions module? 
-def as_variable(obj, name=None): +def as_variable(obj, name=None) -> 'Union[Variable, IndexVariable]': """Convert an object into a Variable. Parameters diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 906ebb278cc..20872aa4088 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -683,7 +683,9 @@ def test_isel_fancy(self): da.isel(time=(('points',), [1, 2]), x=(('points',), [2, 2]), y=(('points',), [3, 4])) np.testing.assert_allclose( - da.isel_points(time=[1], x=[2], y=[4]).values.squeeze(), + da.isel(time=(('p',), [1]), + x=(('p',), [2]), + y=(('p',), [4])).values.squeeze(), np_array[1, 4, 2].squeeze()) da.isel(time=(('points', ), [1, 2])) y = [-1, 0] diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 05884bda4ba..463c6756268 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1416,7 +1416,7 @@ def test_sel_fancy(self): assert_identical(actual['b'].drop('y'), idx_y['b']) with pytest.raises(KeyError): - data.sel_points(x=[2.5], y=[2.0], method='pad', tolerance=1e-3) + data.sel(x=[2.5], y=[2.0], method='pad', tolerance=1e-3) def test_sel_method(self): data = create_test_data()
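Note on the test changes: the deprecated `isel_points`/`sel_points` calls are
rewritten as plain `isel`/`sel` with vectorized indexers that share a new
dimension. A minimal sketch of the equivalent pointwise selection (public API;
the "points" dimension name is arbitrary):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(12).reshape(3, 4), dims=("x", "y"))

    # select the points (x=0, y=1) and (x=2, y=3) in a single call; indexers
    # that share the "points" dimension are broadcast against each other
    picked = da.isel(x=xr.DataArray([0, 2], dims="points"),
                     y=xr.DataArray([1, 3], dims="points"))
    assert picked.values.tolist() == [1, 11]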