Skip to content

Commit

Permalink
Various fixes for explicit Dataset.indexes (#2858)
Browse files Browse the repository at this point in the history
* Various fixes for explicit Dataset.indexes

Fixes GH2856

I've added internal consistency checks to the uses of ``assert_equal`` in our
test suite, so this shouldn't happen again.

* Fix indexes in Dataset.interp
  • Loading branch information
shoyer authored Apr 4, 2019
1 parent aaae999 commit 31619d7
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 99 deletions.
4 changes: 3 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ Bug fixes

- Dataset.copy(deep=True) now creates a deep copy of the attrs (:issue:`2835`).
By `Andras Gefferth <https://github.com/kefirbandi>`_.
- ``swap_dims`` would create incorrect ``indexes`` (:issue:`2842`).
- Fix incorrect ``indexes`` resulting from various ``Dataset`` operations
(e.g., ``swap_dims``, ``isel``, ``reindex``, ``[]``) (:issue:`2842`,
:issue:`2856`).
By `Stephan Hoyer <https://github.com/shoyer>`_.

.. _whats-new.0.12.0:
Expand Down
63 changes: 28 additions & 35 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,36 +315,51 @@ def reindex_variables(
"""
from .dataarray import DataArray

# create variables for the new dataset
reindexed = OrderedDict() # type: OrderedDict[Any, Variable]

# build up indexers for assignment along each dimension
int_indexers = {}
targets = OrderedDict() # type: OrderedDict[Any, pd.Index]
new_indexes = OrderedDict(indexes)
masked_dims = set()
unchanged_dims = set()

# size of reindexed dimensions
new_sizes = {}
for dim, indexer in indexers.items():
if isinstance(indexer, DataArray) and indexer.dims != (dim,):
warnings.warn(
"Indexer has dimensions {0:s} that are different "
"from that to be indexed along {1:s}. "
"This will behave differently in the future.".format(
str(indexer.dims), dim),
FutureWarning, stacklevel=3)

target = new_indexes[dim] = utils.safe_cast_to_index(indexers[dim])

if dim in indexes:
index = indexes[dim]

for name, index in indexes.items():
if name in indexers:
if not index.is_unique:
raise ValueError(
'cannot reindex or align along dimension %r because the '
'index has duplicate values' % name)

target = utils.safe_cast_to_index(indexers[name])
new_sizes[name] = len(target)
'index has duplicate values' % dim)

int_indexer = get_indexer_nd(index, target, method, tolerance)

# We use negative values from get_indexer_nd to signify
# values that are missing in the index.
if (int_indexer < 0).any():
masked_dims.add(name)
masked_dims.add(dim)
elif np.array_equal(int_indexer, np.arange(len(index))):
unchanged_dims.add(name)
unchanged_dims.add(dim)

int_indexers[name] = int_indexer
targets[name] = target
int_indexers[dim] = int_indexer

if dim in variables:
var = variables[dim]
args = (var.attrs, var.encoding) # type: tuple
else:
args = ()
reindexed[dim] = IndexVariable((dim,), target, *args)

for dim in sizes:
if dim not in indexes and dim in indexers:
Expand All @@ -356,25 +371,6 @@ def reindex_variables(
'index because its size %r is different from the size of '
'the new index %r' % (dim, existing_size, new_size))

# create variables for the new dataset
reindexed = OrderedDict() # type: OrderedDict[Any, Variable]

for dim, indexer in indexers.items():
if isinstance(indexer, DataArray) and indexer.dims != (dim,):
warnings.warn(
"Indexer has dimensions {0:s} that are different "
"from that to be indexed along {1:s}. "
"This will behave differently in the future.".format(
str(indexer.dims), dim),
FutureWarning, stacklevel=3)

if dim in variables:
var = variables[dim]
args = (var.attrs, var.encoding) # type: tuple
else:
args = ()
reindexed[dim] = IndexVariable((dim,), indexers[dim], *args)

for name, var in variables.items():
if name not in indexers:
key = tuple(slice(None)
Expand All @@ -395,9 +391,6 @@ def reindex_variables(

reindexed[name] = new_var

new_indexes = OrderedDict(indexes)
new_indexes.update(targets)

return reindexed, new_indexes


Expand Down
3 changes: 0 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,6 @@ def __init__(self, data, coords=None, dims=None, name=None,
coords, dims = _infer_coords_and_dims(data.shape, coords, dims)
variable = Variable(dims, data, attrs, encoding, fastpath=True)

# uncomment for a useful consistency check:
# assert all(isinstance(v, Variable) for v in coords.values())

# These fully describe a DataArray
self._variable = variable
self._coords = coords
Expand Down
48 changes: 31 additions & 17 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,7 @@ def _copy_listed(self: T, names) -> T:
"""
variables = OrderedDict() # type: OrderedDict[Any, Variable]
coord_names = set()
indexes = OrderedDict() # type: OrderedDict[Any, pd.Index]

for name in names:
try:
Expand All @@ -948,6 +949,8 @@ def _copy_listed(self: T, names) -> T:
variables[var_name] = var
if ref_name in self._coord_names or ref_name in self.dims:
coord_names.add(var_name)
if (var_name,) == var.dims:
indexes[var_name] = var.to_index()

needed_dims = set() # type: set
for v in variables.values():
Expand All @@ -959,12 +962,8 @@ def _copy_listed(self: T, names) -> T:
if set(self.variables[k].dims) <= needed_dims:
variables[k] = self._variables[k]
coord_names.add(k)

if self._indexes is None:
indexes = None
else:
indexes = OrderedDict((k, v) for k, v in self._indexes.items()
if k in coord_names)
if k in self.indexes:
indexes[k] = self.indexes[k]

return self._replace(variables, coord_names, dims, indexes=indexes)

Expand Down Expand Up @@ -1503,9 +1502,13 @@ def _validate_indexers(
raise ValueError("dimensions %r do not exist" % invalid)

# all indexers should be int, slice, np.ndarrays, or Variable
indexers_list = []
indexers_list = [] # type: List[Tuple[Any, Union[slice, Variable]]]
for k, v in indexers.items():
if isinstance(v, (slice, Variable)):
if isinstance(v, slice):
indexers_list.append((k, v))
continue

if isinstance(v, Variable):
pass
elif isinstance(v, DataArray):
v = v.variable
Expand All @@ -1524,14 +1527,19 @@ def _validate_indexers(
v = _parse_array_of_cftime_strings(v, index.date_type)

if v.ndim == 0:
v = as_variable(v)
v = Variable((), v)
elif v.ndim == 1:
v = as_variable((k, v))
v = IndexVariable((k,), v)
else:
raise IndexError(
"Unlabeled multi-dimensional array cannot be "
"used for indexing: {}".format(k))

if v.ndim == 1:
v = v.to_index_variable()

indexers_list.append((k, v))

return indexers_list

def _get_indexers_coords_and_indexes(self, indexers):
Expand Down Expand Up @@ -1631,7 +1639,7 @@ def isel(self, indexers=None, drop=False, **indexers_kwargs):

if name in self.indexes:
new_var, new_index = isel_variable_and_index(
var, self.indexes[name], var_indexers)
name, var, self.indexes[name], var_indexers)
if new_index is not None:
indexes[name] = new_index
else:
Expand Down Expand Up @@ -2117,15 +2125,20 @@ def _validate_interp_indexer(x, new_x):
indexes = OrderedDict(
(k, v) for k, v in obj.indexes.items() if k not in indexers)
selected = self._replace_with_new_dims(
variables, coord_names, indexes=indexes)
variables.copy(), coord_names, indexes=indexes)

# attach indexer as coordinate
variables.update(indexers)
indexes.update(
(k, v.to_index()) for k, v in indexers.items() if v.dims == (k,)
)

# Extract coordinates from indexers
coord_vars, new_indexes = (
selected._get_indexers_coords_and_indexes(coords))
variables.update(coord_vars)
indexes.update(new_indexes)

coord_names = (set(variables)
.intersection(obj._coord_names)
.union(coord_vars))
Expand Down Expand Up @@ -2401,6 +2414,7 @@ def expand_dims(self, dim=None, axis=None, **dim_kwargs):
' variable name.'.format(dim=d))

variables = OrderedDict()
coord_names = self._coord_names.copy()
# If dim is a dict, then ensure that the values are either integers
# or iterables.
for k, v in dim.items():
Expand All @@ -2410,7 +2424,7 @@ def expand_dims(self, dim=None, axis=None, **dim_kwargs):
# value within the dim dict to the length of the iterable
# for later use.
variables[k] = xr.IndexVariable((k,), v)
self._coord_names.add(k)
coord_names.add(k)
dim[k] = variables[k].size
elif isinstance(v, int):
pass # Do nothing if the dimensions value is just an int
Expand All @@ -2420,7 +2434,7 @@ def expand_dims(self, dim=None, axis=None, **dim_kwargs):

for k, v in self._variables.items():
if k not in dim:
if k in self._coord_names: # Do not change coordinates
if k in coord_names: # Do not change coordinates
variables[k] = v
else:
result_ndim = len(v.dims) + len(axis)
Expand Down Expand Up @@ -2452,10 +2466,10 @@ def expand_dims(self, dim=None, axis=None, **dim_kwargs):
variables[k] = v.set_dims(k)

new_dims = self._dims.copy()
for d in dim:
new_dims[d] = 1
new_dims.update(dim)

return self._replace(variables, dims=new_dims)
return self._replace_vars_and_dims(
variables, dims=new_dims, coord_names=coord_names)

def set_index(self, indexes=None, append=False, inplace=None,
**indexes_kwargs):
Expand Down
7 changes: 4 additions & 3 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import collections.abc
from collections import OrderedDict
from typing import Any, Iterable, Mapping, Optional, Tuple, Union
from typing import Any, Hashable, Iterable, Mapping, Optional, Tuple, Union

import pandas as pd

Expand Down Expand Up @@ -59,6 +59,7 @@ def default_indexes(


def isel_variable_and_index(
name: Hashable,
variable: Variable,
index: pd.Index,
indexers: Mapping[Any, Union[slice, Variable]],
Expand All @@ -75,8 +76,8 @@ def isel_variable_and_index(

new_variable = variable.isel(indexers)

if new_variable.ndim != 1:
# can't preserve a index if result is not 0D
if new_variable.dims != (name,):
# can't preserve an index if result has new dimensions
return new_variable, None

# we need to compute the new index
Expand Down
43 changes: 38 additions & 5 deletions xarray/testing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Testing functions exposed to the user API"""
from collections import OrderedDict

import numpy as np
import pandas as pd

from xarray.core import duck_array_ops
from xarray.core import formatting
from xarray.core.indexes import default_indexes


def _decode_string_data(data):
Expand Down Expand Up @@ -143,8 +147,37 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
.format(type(a)))


def assert_combined_tile_ids_equal(dict1, dict2):
    """Assert that two mappings have the same keys and equal values.

    Parameters
    ----------
    dict1, dict2 : mapping
        Mappings to compare. They must have the same length, every key of
        ``dict1`` must be present in ``dict2``, and the corresponding values
        must satisfy ``assert_equal``.

    Raises
    ------
    AssertionError
        If the lengths differ, a key of ``dict1`` is missing from ``dict2``,
        or ``assert_equal`` rejects a pair of values.
    """
    # Include the offending values in the message so a failure is diagnosable.
    assert len(dict1) == len(dict2), (len(dict1), len(dict2))
    for k in dict1:
        assert k in dict2, k
        assert_equal(dict1[k], dict2[k])
def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims):
    """Assert that ``indexes`` is consistent with the given variables.

    The checks are, in order:

    1. ``indexes`` is an ``OrderedDict`` whose values are all ``pd.Index``.
    2. Every key of ``indexes`` names an ``IndexVariable`` in
       ``possible_coord_variables``.
    3. ``indexes`` has exactly the keys produced by
       ``default_indexes(possible_coord_variables, dims)``, and each value
       ``equals`` the corresponding default index.

    Raises
    ------
    AssertionError
        If any of the checks above fails.
    """
    import xarray as xr

    assert isinstance(indexes, OrderedDict), indexes
    assert all(
        isinstance(index, pd.Index) for index in indexes.values()
    ), {name: type(index) for name, index in indexes.items()}

    # Names of all variables that are allowed to carry an index.
    index_var_names = set()
    for name, var in possible_coord_variables.items():
        if isinstance(var, xr.IndexVariable):
            index_var_names.add(name)
    assert indexes.keys() <= index_var_names, (set(indexes), index_var_names)

    # Note: when we support non-default indexes, these checks should be opt-in
    # only!
    expected = default_indexes(possible_coord_variables, dims)
    assert indexes.keys() == expected.keys(), (set(indexes), set(expected))
    assert all(
        index.equals(expected[name]) for name, index in indexes.items()
    ), (indexes, expected)


def _assert_indexes_invariants(a):
    """Run the indexes invariant checks on ``a`` when it has explicit indexes.

    ``DataArray`` and ``Dataset`` objects are checked only if their
    ``_indexes`` attribute is not ``None``; any other object (including
    ``Variable``, which has no indexes) is accepted without checks.
    """
    import xarray as xr

    if isinstance(a, xr.DataArray):
        if a._indexes is None:
            return
        _assert_indexes_invariants_checks(a._indexes, a._coords, a.dims)
        return

    if isinstance(a, xr.Dataset):
        if a._indexes is None:
            return
        _assert_indexes_invariants_checks(a._indexes, a._variables, a._dims)
        return

    # xr.Variable (and anything else): no indexes to validate.
25 changes: 23 additions & 2 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from xarray.core import utils
from xarray.core.options import set_options
from xarray.core.indexing import ExplicitlyIndexed
from xarray.testing import (assert_equal, assert_identical, # noqa: F401
assert_allclose, assert_combined_tile_ids_equal)
import xarray.testing
from xarray.plot.utils import import_seaborn

try:
Expand Down Expand Up @@ -180,3 +179,25 @@ def source_ndarray(array):
if base is None:
base = array
return base


# Internal versions of xarray's test functions that validate additional
# invariants
# TODO: add more invariant checks.

def assert_equal(a, b):
    """Like ``xarray.testing.assert_equal``, plus indexes invariant checks.

    After the equality assertion passes, both operands are validated with
    ``xarray.testing._assert_indexes_invariants``.
    """
    xarray.testing.assert_equal(a, b)
    for obj in (a, b):
        xarray.testing._assert_indexes_invariants(obj)


def assert_identical(a, b):
    """Like ``xarray.testing.assert_identical``, plus indexes invariant checks.

    After the identity assertion passes, both operands are validated with
    ``xarray.testing._assert_indexes_invariants``.
    """
    xarray.testing.assert_identical(a, b)
    for obj in (a, b):
        xarray.testing._assert_indexes_invariants(obj)


def assert_allclose(a, b, **kwargs):
    """Like ``xarray.testing.assert_allclose``, plus indexes invariant checks.

    Keyword arguments are forwarded unchanged to
    ``xarray.testing.assert_allclose``. After that assertion passes, both
    operands are validated with ``xarray.testing._assert_indexes_invariants``.
    """
    xarray.testing.assert_allclose(a, b, **kwargs)
    for obj in (a, b):
        xarray.testing._assert_indexes_invariants(obj)
Loading

0 comments on commit 31619d7

Please sign in to comment.