Periodic Boundary Index #7031

Open
TomNicholas opened this issue Sep 13, 2022 · 14 comments

Comments

@TomNicholas
Member

TomNicholas commented Sep 13, 2022

What is your issue?

I would like to create a PeriodicBoundaryIndex using the Explicit Indexes refactor. I want to do it first in 1D, then 2D, then maybe ND.

I'm thinking this would be useful for:

  1. Geoscientists with periodic longitudes
  2. Any scientists with periodic domains
  3. Road-testing the refactor + how easy the documentation is to follow.

Eventually I think perhaps this index should live in xarray itself? As it's domain-agnostic, doesn't introduce extra dependencies, and could be a conceptually simple example of a custom index.

I had a first go, using the benbovy:add-set-xindex-and-drop-indexes branch, and reading the in-progress docs page. I got a bit stuck early on though.

@benbovy here's what I have so far:

import numpy as np
import pandas as pd
import xarray as xr
from xarray.core.variable import Variable
from xarray.core.indexes import PandasIndex, is_scalar

from typing import Union, Mapping, Any


class PeriodicBoundaryIndex(PandasIndex):
    """
    An index representing any 1D periodic numberline.
    
    Implementation subclasses a normal xarray PandasIndex object but intercepts indexer queries.
    """
        
    def _periodic_subset(self, indxr: Union[int, slice, np.ndarray]) -> pd.Index:
        """Equivalent of __getitem__ for a pd.Index, but respects periodicity."""
        
        length = len(self)
        
        if isinstance(indxr, int):
            return self.index[indxr % length]
        elif isinstance(indxr, slice):
            raise NotImplementedError()
        elif isinstance(indxr, np.ndarray):
            raise NotImplementedError()
        else:
            raise TypeError    
    
    def isel(
        self, indexers: Mapping[Any, Union[int, slice, np.ndarray, Variable]]
    ) -> Union["PeriodicBoundaryIndex", None]:

        print("isel called")

        indxr = indexers[self.dim]
        if isinstance(indxr, Variable):
            if indxr.dims != (self.dim,):
                # can't preserve a index if result has new dimensions
                return None
            else:
                indxr = indxr.data
        if not isinstance(indxr, slice) and is_scalar(indxr):
            # scalar indexer: drop index
            return None

        subsetted_index = self._periodic_subset(indxr)  # call the helper method (it is not subscriptable)
        return self._replace(subsetted_index)


airtemps = xr.tutorial.open_dataset("air_temperature")['air']

da = airtemps.drop_indexes("lon")

world = da.set_xindex("lon", index_cls=PeriodicBoundaryIndex)

Now selecting a value with isel inside the range works fine, giving the same result as without my custom index. (The length of the example dataset along lon is 53.)

world.isel(lon=45)
isel called
<xarray.DataArray 'air' (time: 2920, lat: 25)>
...

But indexing with a lon value outside the range of the index data gives an IndexError, seemingly without consulting my new index object. It didn't even print "isel called" 😕 What should I have implemented that I didn't implement?

world.isel(lon=55)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Input In [35], in <cell line: 1>()
----> 1 world.isel(lon=55)

File ~/Documents/Work/Code/xarray/xarray/core/dataarray.py:1297, in DataArray.isel(self, indexers, drop, missing_dims, **indexers_kwargs)
   1292     return self._from_temp_dataset(ds)
   1294 # Much faster algorithm for when all indexers are ints, slices, one-dimensional
   1295 # lists, or zero or one-dimensional np.ndarray's
-> 1297 variable = self._variable.isel(indexers, missing_dims=missing_dims)
   1298 indexes, index_variables = isel_indexes(self.xindexes, indexers)
   1300 coords = {}

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:1233, in Variable.isel(self, indexers, missing_dims, **indexers_kwargs)
   1230 indexers = drop_dims_from_indexers(indexers, self.dims, missing_dims)
   1232 key = tuple(indexers.get(dim, slice(None)) for dim in self.dims)
-> 1233 return self[key]

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:793, in Variable.__getitem__(self, key)
    780 """Return a new Variable object whose contents are consistent with
    781 getting the provided key from the underlying data.
    782 
   (...)
    790 array `x.values` directly.
    791 """
    792 dims, indexer, new_order = self._broadcast_indexes(key)
--> 793 data = as_indexable(self._data)[indexer]
    794 if new_order:
    795     data = np.moveaxis(data, range(len(new_order)), new_order)

File ~/Documents/Work/Code/xarray/xarray/core/indexing.py:657, in MemoryCachedArray.__getitem__(self, key)
    656 def __getitem__(self, key):
--> 657     return type(self)(_wrap_numpy_scalars(self.array[key]))

File ~/Documents/Work/Code/xarray/xarray/core/indexing.py:626, in CopyOnWriteArray.__getitem__(self, key)
    625 def __getitem__(self, key):
--> 626     return type(self)(_wrap_numpy_scalars(self.array[key]))

File ~/Documents/Work/Code/xarray/xarray/core/indexing.py:533, in LazilyIndexedArray.__getitem__(self, indexer)
    531     array = LazilyVectorizedIndexedArray(self.array, self.key)
    532     return array[indexer]
--> 533 return type(self)(self.array, self._updated_key(indexer))

File ~/Documents/Work/Code/xarray/xarray/core/indexing.py:505, in LazilyIndexedArray._updated_key(self, new_key)
    503         full_key.append(k)
    504     else:
--> 505         full_key.append(_index_indexer_1d(k, next(iter_new_key), size))
    506 full_key = tuple(full_key)
    508 if all(isinstance(k, integer_types + (slice,)) for k in full_key):

File ~/Documents/Work/Code/xarray/xarray/core/indexing.py:278, in _index_indexer_1d(old_indexer, applied_indexer, size)
    276         indexer = slice_slice(old_indexer, applied_indexer, size)
    277     else:
--> 278         indexer = _expand_slice(old_indexer, size)[applied_indexer]
    279 else:
    280     indexer = old_indexer[applied_indexer]

IndexError: index 55 is out of bounds for axis 0 with size 53
@keewis
Collaborator

keewis commented Sep 13, 2022

shouldn't you implement (and call) sel instead of isel if you're working in coordinate space?

@TomNicholas
Member Author

TomNicholas commented Sep 13, 2022 via email

@benbovy
Member

benbovy commented Sep 14, 2022

tl;dr: Xarray Index currently supports implementing periodic indexing for label-based indexing but not for location-based (integer) indexing.

There's a big difference now between isel and sel:

  • Dataset.isel() accepts dimension names only
  • Dataset.sel() accepts coordinate names (actually, it falls back to isel when giving dimension names with no coordinate, and I'm wondering if we shouldn't deprecate that?)

Index.isel() is convenient when the underlying index structure can be itself sliced (like pandas.Index objects), so that users don't need to do ds.isel(...).set_xindex(...) every time to explicitly rebuild an index after slicing the Dataset. For a kd-tree structure that may not be possible, i.e., KDTreeIndex.isel() would likely return None causing the index to be dropped in the result, so there would be no way around doing ds.isel(...).set_xindex(...).

Most coordinate and data variables are still sliced via Variable.isel(), which doesn't involve any index. That's why you get an IndexError in your example. (side note: the "index" / "indexing" terminology used everywhere, for both label and integer selection, is quite confusing but I'm not sure how this could be improved).

If we want to support periodic indexing with isel, we would have to implement that in Xarray itself. Alternatively, it might be possible to add some API in Index so that in the case of a periodic index it would return indxr % length from indxr, which Xarray will then pass to Variable.isel(). I'm not sure the latter is a good idea, though. Indexes may work with arbitrary coordinates and dimensions, which would make things too complex (handling conflicts, etc.). Also, I don't know if there are other potential use cases besides periodic indexing?
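As a rough illustration of the second alternative (an index that turns indxr into indxr % length before Xarray passes it on to Variable.isel()), here is a pure-Python sketch of just the remapping step. wrap_positions is a hypothetical helper, not an Xarray API; it handles ints, lists of ints and positive-step slices, expanding slices into explicit positions because wrapped positions are generally no longer contiguous.

```python
def wrap_positions(indexer, length):
    """Map integer positions onto [0, length), as a periodic index might.

    Hypothetical helper, not an Xarray API. Slices are expanded into an explicit
    list because the wrapped positions are generally not contiguous.
    """
    if isinstance(indexer, int):
        return indexer % length
    if isinstance(indexer, slice):
        step = 1 if indexer.step is None else indexer.step
        if step <= 0:
            raise ValueError("this sketch only handles positive steps")
        start = 0 if indexer.start is None else indexer.start
        stop = length if indexer.stop is None else indexer.stop
        # range() is not clipped to the axis, so positions may run past either end
        return [i % length for i in range(start, stop, step)]
    return [int(i) % length for i in indexer]


# 53 longitudes, as in the air_temperature example above
print(wrap_positions(55, 53))             # 2
print(wrap_positions(slice(51, 55), 53))  # [51, 52, 0, 1]
print(wrap_positions([-1, 53], 53))       # [52, 0]
```

The remapping itself is trivial; the hard part described above is that Variable.isel() never asks the index first, so Xarray would need some API to run this step.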

@TomNicholas your experiment makes it clear that the documentation on this part (#6975) should be improved. Thanks!

@TomNicholas
Member Author

TomNicholas commented Sep 14, 2022 via email

@benbovy
Member

benbovy commented Sep 14, 2022

My understanding from reading the docs was that every Dataset.meth calls the corresponding Index.meth.

Yes that's indeed what I've written in #6975 and I realize now that this is confusing, especially for isel.

So Dataset.sel calls Index.sel, but can also sometimes call Dataset.isel. But Dataset.isel does not call Index.isel, nor Index.sel.

So we can describe the implementation of Dataset.sel() as a two-step procedure:

  1. remap the input dictionary {coord_name: label_values} to a dictionary {dimension_name: int_positions}.

    • This is done via dispatching the input dictionary and calling Index.sel() for each of the relevant indexes found in Dataset.xindexes, and then merging all the returned results into a single output dictionary.
  2. pass the dictionary {dimension_name: int_positions} to Dataset.isel().

    • Dataset.isel() will dispatch this input dictionary and call Variable.isel() for each variable in Dataset.variables and Index.isel() for each unique index in Dataset.xindexes.

This omits a few implementation details (special cases for multi-index), but that's basically how it works.

I think it would help if such "how label-based selection works in Xarray" high-level description was added somewhere in the "Xarray internals" documentation, along with other "how it works" sections for, e.g., alignment.
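The two-step procedure can be modelled with plain Python containers, which also shows why world.isel(lon=55) never reached the custom index: only step 1 consults an index, and step 2 is purely positional. All names here (build_index, toy_sel, toy_isel, wrap_lon) are hypothetical stand-ins, not Xarray code.

```python
def build_index(labels):
    """Toy 'index': maps each coordinate label to its integer position."""
    return {label: pos for pos, label in enumerate(labels)}


def toy_isel(data, positions):
    """Step 2: purely positional selection. No index is consulted here."""
    if isinstance(positions, int):
        return data[positions]  # IndexError if past the end, like the traceback above
    return [data[p] for p in positions]


def toy_sel(data, index, labels, remap=None):
    """Step 1 (labels -> positions via the index), then step 2 (toy_isel).

    `remap` stands in for whatever a custom Index.sel() does to the labels
    (e.g. periodic wrapping) before the positions are looked up.
    """
    remap = remap or (lambda label: label)
    if isinstance(labels, list):
        positions = [index[remap(label)] for label in labels]
    else:
        positions = index[remap(labels)]
    return toy_isel(data, positions)


def wrap_lon(lon):
    """Fold a longitude into [-180, 180)."""
    return (lon + 180) % 360 - 180


lons = [-180, -90, 0, 90]
temps = [10.0, 11.0, 12.0, 13.0]
idx = build_index(lons)

print(toy_sel(temps, idx, 270, remap=wrap_lon))  # 270 -> -90 -> position 1 -> 11.0
try:
    toy_isel(temps, 5)  # an integer query never passes through the index
except IndexError as err:
    print("IndexError:", err)
```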

@TomNicholas
Member Author

TomNicholas commented Sep 14, 2022 via email

@TomNicholas
Member Author

TomNicholas commented Sep 14, 2022

I had another go and now I have this (the .sel method is just copied from PandasIndex.sel with minor changes):

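from typing import Any, Mapping, Union

import numpy as np
import pandas as pd

from xarray.core.indexes import (
    PandasIndex, is_scalar, as_scalar, get_indexer_nd, IndexSelResult,
    _query_slice, is_dict_like, normalize_label,
)
from xarray.core.variable import Variable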


class PeriodicBoundaryIndex(PandasIndex):
    """
    An index representing any 1D periodic numberline.
    
    Implementation subclasses a normal xarray PandasIndex object but intercepts indexer queries.
    """
    period: float
    
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.period = 360  # TODO work out where this input should be passed in instead of hard-coding
    
    def _wrap_periodically(self, label_value):
        # fold into [min, min + period); anchoring at the minimum is correct for any grid spacing
        return self.index.min() + (label_value - self.index.min()) % self.period
    
    def sel(
        self, labels: dict[Any, Any], method=None, tolerance=None
    ) -> IndexSelResult:
        """Remaps labels outside of the index's range back to integer indices inside the range."""
        
        print("sel called")
        
        from xarray import DataArray
        from xarray import Variable

        if method is not None and not isinstance(method, str):
            raise TypeError("``method`` must be a string")

        assert len(labels) == 1
        coord_name, label = next(iter(labels.items()))

        if isinstance(label, slice):
            print(label)
            indexer = _query_slice(self.index, label, coord_name, method, tolerance)
            print(indexer)
        elif is_dict_like(label):
            raise ValueError(
                "cannot use a dict-like object for selection on "
                "a dimension that does not have a MultiIndex"
            )
        else:
            label_array = normalize_label(label, dtype=self.coord_dtype)
            if label_array.ndim == 0:
                label_array = self._wrap_periodically(label_array)
                label_value = as_scalar(label_array)
                if isinstance(self.index, pd.CategoricalIndex):
                    if method is not None:
                        raise ValueError(
                            "'method' is not supported when indexing using a CategoricalIndex."
                        )
                    if tolerance is not None:
                        raise ValueError(
                            "'tolerance' is not supported when indexing using a CategoricalIndex."
                        )
                    indexer = self.index.get_loc(label_value)
                else:
                    if method is not None:
                        print(label_array)
                        indexer = get_indexer_nd(
                            self.index, label_array, method, tolerance
                        )
                        if np.any(indexer < 0):
                            raise KeyError(
                                f"not all values found in index {coord_name!r}"
                            )
                    else:
                        try:
                            print(label_value)
                            indexer = self.index.get_loc(label_value)
                        except KeyError as e:
                            raise KeyError(
                                f"not all values found in index {coord_name!r}. "
                                "Try setting the `method` keyword argument (example: method='nearest')."
                            ) from e

            elif label_array.dtype.kind == "b":
                indexer = label_array
            else:
                indexer = get_indexer_nd(self.index, label_array, method, tolerance)
                if np.any(indexer < 0):
                    raise KeyError(f"not all values found in index {coord_name!r}")

            # attach dimension names and/or coordinates to positional indexer
            if isinstance(label, Variable):
                indexer = Variable(label.dims, indexer)
            elif isinstance(label, DataArray):
                indexer = DataArray(indexer, coords=label._coords, dims=label.dims)

        return IndexSelResult({self.dim: indexer})

    def isel(
        self, indexers: Mapping[Any, Union[int, slice, np.ndarray, Variable]]
    ) -> Union["PeriodicBoundaryIndex", None]:

        print("isel called")
        return super().isel(indexers=indexers)

This works for integer indexing with sel!

lon_coord = xr.DataArray(data=np.linspace(-180, 180, 19), dims="lon")
da = xr.DataArray(data=np.random.randn(19), dims="lon", coords={"lon": lon_coord})
<xarray.DataArray (lon: 19)>
array([-0.67423202,  0.14173693, -0.51427002,  1.25764101,  0.23863066,
        0.05703135, -0.65350384, -0.74356356,  0.98524252, -0.94975665,
        0.63314842, -0.7144752 ,  0.47282375,  0.31555171, -0.13179154,
       -1.10255267,  0.88180541,  1.28461459,  1.61273741])
Coordinates:
  * lon      (lon) float64 -180.0 -160.0 -140.0 -120.0 ... 140.0 160.0 180.0
world = da.drop_indexes("lon").set_xindex("lon", index_cls=PeriodicBoundaryIndex)

world.sel(lon=200, method="nearest")
<xarray.DataArray ()>
array(0.14173693)
Coordinates:
    lon      float64 -160.0

Yay! 🍾
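The folding arithmetic in _wrap_periodically can be checked on its own. Below is a minimal sketch with a hypothetical free function, wrap_label, standing in for the method: it folds any label into the half-open interval [lo, lo + period). Anchoring the modulo at the minimum label, as here, works for any grid; anchoring at the maximum label instead agrees with it only when max - min == period, which happens to hold for np.linspace(-180, 180, 19).

```python
def wrap_label(value, lo, period):
    """Fold a coordinate label into [lo, lo + period).

    Hypothetical stand-in for PeriodicBoundaryIndex._wrap_periodically.
    Python's % takes the sign of the divisor, so negative offsets wrap too.
    """
    return lo + (value - lo) % period


# Grid with both endpoints present: -180, -160, ..., 180 (period 360)
print(wrap_label(200, -180, 360))   # -160
print(wrap_label(-210, -180, 360))  # 150
print(wrap_label(180, -180, 360))   # -180 (the duplicated endpoint folds onto the minimum)

# Grid without a duplicated endpoint: 0, 10, ..., 350 (period 360)
print(wrap_label(10, 0, 360))       # 10
# Anchoring at the maximum (350) instead would give 0 + (10 - 350) % 360 == 20
```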

Q: Best way to do this for slicing?

I want this to work

world.sel(lon=slice(170, 190))

Internally that means PeriodicBoundaryIndex.sel has to return an indexer that points to values at both the start and end of the array data. I'm not sure what the best way of doing this is. Originally I imagined returning two slices but I don't think that's a valid argument to Dataset.isel().

So I guess I have to turn my slice into a list of specific integer positions and pass that to .isel()? How do I do that? Is that going to be inefficient somehow?

I guess I also want to reorder the result before returning it, otherwise the two sides of the dateline won't be stitched together in the right order...
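For the first question, one option is to turn the slice into an explicit list of integer positions, already in stitched order, since Dataset.isel() accepts a 1D integer array as an indexer. A pure-Python sketch, assuming the labels are sorted ascending and both endpoints have already been folded into the index range; slice_positions is a hypothetical name, and bisect plays the role of the label lookup that pandas.Index.slice_indexer performs for a real index.

```python
from bisect import bisect_left, bisect_right


def slice_positions(labels, start, stop):
    """Integer positions for the inclusive label slice [start, stop] on a periodic axis.

    Hypothetical helper. Assumes `labels` is sorted ascending and that start and
    stop have already been folded into [labels[0], labels[-1]]. If start > stop
    the slice crosses the boundary, so it is split into [start, max] and
    [min, stop], and the two runs are concatenated in that order.
    """
    if start <= stop:
        return list(range(bisect_left(labels, start), bisect_right(labels, stop)))
    tail = range(bisect_left(labels, start), len(labels))  # start ... max
    head = range(0, bisect_right(labels, stop))  # min ... stop
    return list(tail) + list(head)


lons = list(range(-180, 181, 20))  # -180, -160, ..., 180, like np.linspace(-180, 180, 19)
positions = slice_positions(lons, 160, -150)  # slice(160, 210) with 210 folded to -150
print(positions)                     # [17, 18, 0, 1]
print([lons[i] for i in positions])  # [160, 180, -180, -160]
```

On efficiency: the list has one entry per selected element, so building it costs time proportional to the selection size. A boundary-crossing selection is not contiguous, so it could not be expressed as a single slice anyway.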

Q: Where should I pass in period?

If I want the period of the PeriodicBoundaryIndex to be a general parameter, independent of the data in the array, the attributes, or the values of the index labels, where would be the most sensible place to pass it in? .set_xindex only accepts a class, not an instance, and I can't use .from_variables because the period can't be deduced from the variables in general, so where can I feed it in?

Q: How to avoid just copying all of PandasIndex.sel's implementation?

I find myself copying the entire implementation of PandasIndex.sel just to insert one or two lines in predictable places.

Also, pointing me to the implementation of PandasIndex is going to lead to me using lots of private functions from xarray.core.indexes, because I have to import them in order to copy-paste code from PandasIndex.

I wonder if these problems could be ameliorated by providing public entry methods in the Index superclass? I'm thinking about how in anytree (which I used to make the first prototype of datatree) there are these methods that do nothing by default but are intended to be overridden to insert functionality at key steps. The pattern is basically this:

# library code
class NodeMixin:
    """Inherit from this to create your own TreeNode class with parent and children"""

    def _pre_detach_children(self, children):
        """Method call before detaching `children`."""
        pass

    def _post_detach_children(self, children):
        """Method call after detaching `children`."""
        pass

    def _pre_attach_children(self, children):
        """Method call before attaching `children`."""
        pass

    def _post_attach_children(self, children):
        """Method call after attaching `children`."""
        pass


# user code
class MyHappyTreeNode(NodeMixin):
    def _pre_attach_children(self, children):
        """Celebrates the gift of children"""
        print("A child is born!")

What if we put similar methods on the PandasIndex superclass? Like

class PandasIndex:
    def _post_process_label_value(self, label_value: float) -> float:
        """Method call after determining scalar label value."""
        return label_value

    def sel(
        self, labels: dict[Any, Any], method=None, tolerance=None
    ) -> IndexSelResult:
        # rest of the function as before
        ...

        if isinstance(label, slice):
            indexer = _query_slice(self.index, label, coord_name, method, tolerance)
        elif is_dict_like(label):
            raise ValueError(
                "cannot use a dict-like object for selection on "
                "a dimension that does not have a MultiIndex"
            )
        else:
            label_array = normalize_label(label, dtype=self.coord_dtype)
            if label_array.ndim == 0:
                label_value = as_scalar(label_array)
                label_value = self._post_process_label_value(label_value)  # new bit
                ...
        # rest of the function as before

Then in my case I would not have had to copy so much of the implementation, I could have simply done

class PeriodicBoundaryIndex(PandasIndex):
    period: float
    
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.period = 360  # TODO work out where this input should be passed in instead of hard-coding
    
    def _wrap_periodically(self, label_value):
        # fold into [min, min + period); anchoring at the minimum is correct for any grid spacing
        return self.index.min() + (label_value - self.index.min()) % self.period

    def _post_process_label_value(self, label_value):
        return self._wrap_periodically(label_value)

Maybe this is a bad idea / doesn't make sense but I thought I would suggest it anyway.
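To make the hook proposal concrete, here is a self-contained toy version of the pattern with no Xarray involved: the base class's sel calls a do-nothing _post_process_label_value hook, and a subclass overrides only that hook. ToyLabelIndex and ToyPeriodicIndex are hypothetical classes that look labels up in a plain dict rather than a pandas.Index.

```python
class ToyLabelIndex:
    """Maps scalar labels to integer positions (hypothetical stand-in for PandasIndex)."""

    def __init__(self, labels):
        self.labels = list(labels)
        self._positions = {label: i for i, label in enumerate(self.labels)}

    def _post_process_label_value(self, label_value):
        """Hook: called after the scalar label is determined. No-op by default."""
        return label_value

    def sel(self, label_value):
        label_value = self._post_process_label_value(label_value)  # the hook
        try:
            return self._positions[label_value]
        except KeyError:
            raise KeyError(f"{label_value!r} not found in index") from None


class ToyPeriodicIndex(ToyLabelIndex):
    """Overrides only the hook; the lookup logic is inherited unchanged."""

    def __init__(self, labels, period=360):
        super().__init__(labels)
        self.period = period
        self._min = min(self.labels)

    def _post_process_label_value(self, label_value):
        return self._min + (label_value - self._min) % self.period


plain = ToyLabelIndex([0, 90, 180, 270])
periodic = ToyPeriodicIndex([0, 90, 180, 270])
print(periodic.sel(450))  # 450 folds to 90 -> position 1
```

The subclass never touches the lookup code. What the toy leaves out is where such a hook should sit in a real sel() that also handles slices, arrays, Variables and DataArrays.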

@benbovy
Member

benbovy commented Sep 15, 2022

Great @TomNicholas!

To avoid copying the body of PandasIndex.sel, couldn't you "just" do something like this?

class PeriodicBoundaryIndex(PandasIndex):
    """
    An index representing any 1D periodic numberline.
    
    Implementation subclasses a normal xarray PandasIndex object but intercepts indexer queries.
    """
    period: float
    
    def __init__(self, *args, period=360, **kwargs):
        super().__init__(*args, **kwargs)
        self.period = period
        
    @classmethod
    def from_variables(cls, variables, options):
        obj = super().from_variables(variables, options={})
        obj.period = options.get("period", obj.period)
        return obj
    
    def _wrap_periodically(self, label_value):
        # fold into [min, min + period); anchoring at the minimum is correct for any grid spacing
        return self.index.min() + (label_value - self.index.min()) % self.period
    
    def sel(
        self, labels: dict[Any, Any], method=None, tolerance=None
    ) -> IndexSelResult:
        """Remaps labels outside of the index's range back to integer indices inside the range."""

        assert len(labels) == 1
        coord_name, label = next(iter(labels.items()))
        
        if isinstance(label, slice):
            wrapped_label = slice(
                self._wrap_periodically(label.start),
                self._wrap_periodically(label.stop),
            )
            # PandasIndex.sel does not support method/tolerance with slices
            return super().sel({coord_name: wrapped_label})
        else:
            wrapped_label = self._wrap_periodically(label)
            # forward method/tolerance so that e.g. method="nearest" still applies
            return super().sel({coord_name: wrapped_label}, method=method, tolerance=tolerance)

Note: I also added period as an option, which is supported in #6971 but not yet well documented. Another way to pass options is via coordinate attributes, like in this FunctionalIndex example.

It should work in most cases I think:

lon_coord = xr.DataArray(data=np.linspace(-180, 180, 19), dims="lon")
da = xr.DataArray(data=np.random.randn(19), dims="lon", coords={"lon": lon_coord})

# note the period set here
world = da.drop_indexes("lon").set_xindex("lon", index_cls=PeriodicBoundaryIndex, period=360)
world.sel(lon=200, method="nearest")
# <xarray.DataArray ()>
# array(-0.86583185)
# Coordinates:
#     lon      float64 -160.0

world.sel(lon=[200, 200], method="nearest")
# <xarray.DataArray (lon: 2)>
# array([-0.86583185, -0.86583185])
# Coordinates:
#   * lon      (lon) float64 -160.0 -160.0

world.sel(lon=slice(180, 200), method="nearest")
# <xarray.DataArray (lon: 2)>
# array([-1.59829997, -0.86583185])
# Coordinates:
#   * lon      (lon) float64 -180.0 -160.0

There are likely more things to do for slices, as you point out. I don't think it's possible to pass two slices to isel either. Not sure how this could be handled, but probably the easiest is to raise for cases like world.sel(lon=slice(170, 190)).

If we really need more flexibility in sel without copying the whole body of PandasIndex.sel, we could indeed refactor PandasIndex to allow more customization in subclasses. We must be careful, though, as it may be harder to make changes without possibly breaking 3rd-party stuff.

Or, as you suggest, we could define some _pre_process / _post_process hooks. It's not obvious where to call those hooks, though. Before or after converting from/to Variable or DataArray? Before or after checking for slices? For arrays or scalars? The ideal place may change from one index to another.

@TomNicholas
Member Author

Nice @benbovy ! That seems useable already - I'll open a PR to work on it more. It's also much neater 😄

Note: I also added period as an option

Great, I was hoping there was some functionality like that.

Not sure how this could be handled, but probably the easiest is to raise for cases like world.sel(lon=slice(170, 190)).

That would be a shame IMO, so I'll have more of a think about how to handle slicing across the dateline.

We must be careful, though, as it may be harder to make changes without possibly breaking 3rd-party stuff.

It's not obvious where to call those hooks, though.

All good points.

The other thing I will think about is whether anything special needs to happen for 2D+ periodicity. I suspect that for integers you could just use independent 1D indexes along each dim but for slicing across the "dateline" it might get messy in higher dimensions...

@dcherian
Contributor

In general, it seems like most (nearly all?) 1D indexing use cases can be handled by encapsulating a PandasIndex (see also https://github.com/dcherian/crsindex). So perhaps we should just recommend that and add a lot more comments to PandasIndex to make it easier to build on.
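Encapsulating means holding an inner index as an attribute and delegating to it, rather than subclassing it. A toy sketch of that shape in pure Python: InnerIndex is a hypothetical stand-in for a PandasIndex and only sel is shown; a real wrapper would also need to forward the rest of the Index API (isel, equals, concat, and so on).

```python
class InnerIndex:
    """Hypothetical stand-in for a PandasIndex: exact label -> position lookup."""

    def __init__(self, labels):
        self.labels = list(labels)

    def sel(self, label):
        return self.labels.index(label)  # ValueError if the label is missing


class WrappedPeriodicIndex:
    """Holds an InnerIndex and delegates to it after folding the label."""

    def __init__(self, inner, period):
        self._inner = inner
        self.period = period
        self._min = min(inner.labels)

    def sel(self, label):
        folded = self._min + (label - self._min) % self.period
        return self._inner.sel(folded)  # delegation instead of super()


idx = WrappedPeriodicIndex(InnerIndex([0, 120, 240]), period=360)
print(idx.sel(480))  # 480 folds to 120 -> position 1
```

The wrapper depends only on the inner object's public methods, which may be the appeal compared with copying code that relies on private helpers.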

@headtr1ck
Collaborator

Not sure how this could be handled, but probably the easiest is to raise for cases like world.sel(lon=slice(170, 190)).

One could split it into two calls to isel and concatenate the result. Not sure if that's possible with the given interface.
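Splitting into two selections and concatenating gives the same elements as a single positional selection with the concatenated positions, and only the latter fits the one-indexer-per-dimension result that Index.sel() returns. A pure-Python sketch with a plain list standing in for the array; range(*s.indices(n)) expands a slice s into explicit positions for a length-n axis.

```python
data = ["a", "b", "c", "d", "e", "f"]

# Option 1: two contiguous (slice-based) selections, then concatenate
two_calls = data[4:6] + data[0:2]

# Option 2: expand both slices into explicit positions, then select once
positions = list(range(*slice(4, 6).indices(len(data)))) + list(
    range(*slice(0, 2).indices(len(data)))
)
one_call = [data[i] for i in positions]

print(two_calls)  # ['e', 'f', 'a', 'b']
print(positions)  # [4, 5, 0, 1]
print(one_call == two_calls)  # True
```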

@TomNicholas
Member Author

TomNicholas commented Sep 15, 2022

Okay I think this design could work for slicing across boundaries:

from typing import Any

import numpy as np

from xarray.core.indexes import PandasIndex, IndexSelResult, _query_slice
from xarray.core.indexing import _expand_slice


class PeriodicBoundaryIndex(PandasIndex):
    """
    An index representing any 1D periodic numberline.
    
    Implementation subclasses a normal xarray PandasIndex object but intercepts indexer queries.
    """
    period: float
    _min: float
    _max: float
    
    # only the new attributes: PandasIndex already declares "index", "dim" and "coord_dtype"
    __slots__ = ("period", "_max", "_min")
    
    def __init__(self, *args, period=360, **kwargs):
        super().__init__(*args, **kwargs)
        self.period = period
        self._min = self.index.min()
        self._max = self.index.max()
        
    @classmethod
    def from_variables(cls, variables, options):
        obj = super().from_variables(variables, options={})
        obj.period = options.get("period", obj.period)
        return obj
    
    def _wrap_periodically(self, label_value: float) -> float:
        """Remaps an individual point label back to another inside the range [min, min + period)."""
        return self._min + (label_value - self._min) % self.period
        
    def _split_slice_across_boundary(self, label: slice) -> np.ndarray:
        """
        Splits a slice into two slices, one either side of the boundary,
        finds the corresponding indices, concatenates them, 
        and returns them ready to be passed to .isel().
        """
        first_slice = slice(label.start, self._max, label.step)
        second_slice = slice(self._min, self._wrap_periodically(label.stop), label.step)
        
        first_as_index_slice = _query_slice(self.index, first_slice)
        second_as_index_slice = _query_slice(self.index, second_slice)
        
        first_as_indices = _expand_slice(first_as_index_slice, self.index.size)
        second_as_indices = _expand_slice(second_as_index_slice, self.index.size)
        
        wrapped_indices = np.concatenate([first_as_indices, second_as_indices])
        return wrapped_indices
    
    def sel(
        self, labels: dict[Any, Any], method=None, tolerance=None
    ) -> IndexSelResult:
        """Remaps labels outside of the index's range back to integer indices inside the range."""

        assert len(labels) == 1
        coord_name, label = next(iter(labels.items()))
        
        if isinstance(label, slice):
            # TODO enumerate all the possible cases
            if self._min < label.start < self._max and self._min < label.stop < self._max:
                # simple case of slice not crossing boundary
                wrapped_label = slice(
                    self._wrap_periodically(label.start),
                    self._wrap_periodically(label.stop),
                )
                return super().sel({coord_name: wrapped_label})
            elif self._min < label.start < self._max and label.start < self._max < label.stop:
                # nasty case of slice crossing boundary
                wrapped_indices = self._split_slice_across_boundary(label)
                return IndexSelResult({self.dim: wrapped_indices})
            else:
                # TODO there are many other cases to handle...
                raise NotImplementedError()
        else:
            # just a scalar
            wrapped_label = self._wrap_periodically(label)
            return super().sel({coord_name: wrapped_label}, method=method, tolerance=tolerance)
    
    def __repr__(self) -> str:
        return f"PeriodicBoundaryIndex(period={self.period})"


world.sel(lon=slice(60, 120), method="nearest")
# <xarray.DataArray (lon: 4)>
# array([-0.71424378, -0.87270922, -0.9701637 , -0.99979417])
# Coordinates:
#   * lon      (lon) float64 60.0 80.0 100.0 120.0

This works even for slices that cross the dateline

world.sel(lon=slice(160, 210), method="nearest")
# <xarray.DataArray (lon: 4)>
# array([-0.85218366, -0.68526211,  0.68526211,  0.85218366])
# Coordinates:
#   * lon      (lon) float64 160.0 180.0 -180.0 -160.0

This isn't general yet, and there are lots of edge cases it would fail on, but I think it shows that as long as each case is captured, we could always use this approach to remap back to index values that lie within the range. What do people think?

EDIT:

One could split it into two calls to isel and concatenate the result. Not sure if that's possible with the given interface.

I believe what I've done here is the closest thing to that that is possible with the given interface.

@TomNicholas
Member Author

TomNicholas commented Sep 16, 2022

I think this version does something sensible for all slice cases

from typing import Any

import numpy as np

from xarray.core.indexes import (
    PandasIndex, IndexSelResult, _query_slice
)
from xarray.core.indexing import _expand_slice


class PeriodicBoundaryIndex(PandasIndex):
    """
    An index representing any 1D periodic numberline.
    
    Implementation subclasses a normal xarray PandasIndex object but intercepts indexer queries.
    """
    period: float
    _min: float
    _max: float
    
    # only the new attributes: PandasIndex already declares "index", "dim" and "coord_dtype"
    __slots__ = ("period", "_max", "_min")
    
    def __init__(self, *args, period=360, **kwargs):
        super().__init__(*args, **kwargs)
        self.period = period
        self._min = self.index.min()
        self._max = self.index.max()
        
    @classmethod
    def from_variables(cls, variables, options):
        obj = super().from_variables(variables, options={})
        obj.period = options.get("period", obj.period)
        return obj
    
    def _wrap_periodically(self, label_value: float) -> float:
        # fold into [min, min + period); anchoring at the minimum is correct for any grid spacing
        return self._min + (label_value - self._min) % self.period
        
    def _split_slice_across_boundary(self, label: slice) -> np.ndarray:
        """
        Splits a slice into two slices, one either side of the boundary,
        finds the corresponding indices, concatenates them, and returns them,
        ready to be passed to .isel().
        """
        first_slice = slice(label.start, self._max, label.step)
        second_slice = slice(self._min, label.stop, label.step)
        
        first_as_index_slice = _query_slice(self.index, first_slice)
        second_as_index_slice = _query_slice(self.index, second_slice)
        
        first_as_indices = _expand_slice(first_as_index_slice, self.index.size)
        second_as_indices = _expand_slice(second_as_index_slice, self.index.size)
        
        wrapped_indices = np.concatenate([first_as_indices, second_as_indices])
        return wrapped_indices
    
    def sel(
        self, labels: dict[Any, Any], method=None, tolerance=None
    ) -> IndexSelResult:
        """Remaps labels outside of the index's range back to integer indices inside the range."""

        assert len(labels) == 1
        coord_name, label = next(iter(labels.items()))
        
        if isinstance(label, slice):
            start, stop, step = label.start, label.stop, label.step
            if stop < start:
                return super().sel({coord_name: []})
            
            assert self._min < self._max
            
            wrapped_start = self._wrap_periodically(label.start)
            wrapped_stop = self._wrap_periodically(label.stop)
            wrapped_label = slice(wrapped_start, wrapped_stop, step)
            
            if wrapped_start < wrapped_stop:
                # simple case of slice not crossing boundary
                return super().sel({coord_name: wrapped_label})
            else:  # wrapped_stop < wrapped_start:
                # nasty case of slice crossing boundary
                wrapped_indices = self._split_slice_across_boundary(wrapped_label)
                return IndexSelResult({self.dim: wrapped_indices})
            
        else:
            # just a scalar / array of scalars
            wrapped_label = self._wrap_periodically(label)
            return super().sel({coord_name: wrapped_label}, method=method, tolerance=tolerance)
    
    def __repr__(self) -> str:
        return f"PeriodicBoundaryIndex(period={self.period})"


lon_coord = xr.DataArray(data=np.linspace(-180, 180, 19), dims="lon")
da = xr.DataArray(data=np.sin(180*lon_coord), dims="lon", coords={"lon": lon_coord})

world = da.drop_indexes("lon").set_xindex("lon", index_cls=PeriodicBoundaryIndex, period=360)
world.sel(lon=slice(60, 120), method="nearest")
# <xarray.DataArray (lon: 4)>
# array([-0.71424378, -0.87270922, -0.9701637 , -0.99979417])
# Coordinates:
#   * lon      (lon) float64 60.0 80.0 100.0 120.0
world.sel(lon=slice(160, 210), method="nearest")
# <xarray.DataArray (lon: 4)>
# array([-0.85218366, -0.68526211,  0.68526211,  0.85218366])
# Coordinates:
#   * lon      (lon) float64 160.0 180.0 -180.0 -160.0
world.sel(lon=slice(-210, -160), method="nearest")
# <xarray.DataArray (lon: 4)>
# array([-0.85218366, -0.68526211,  0.68526211,  0.85218366])
# Coordinates:
#   * lon      (lon) float64 160.0 180.0 -180.0 -160.0

Unsure as to whether this next one counts as an "intuitive" result or not

world.sel(lon=slice(-210, 210), method="nearest")
# <xarray.DataArray (lon: 4)>
# array([-0.85218366, -0.68526211,  0.68526211,  0.85218366])
# Coordinates:
#   * lon      (lon) float64 160.0 180.0 -180.0 -160.0
world.sel(lon=slice(120, 60), method="nearest")
# <xarray.DataArray (lon: 0)>
# array([], dtype=float64)
# Coordinates:
#   * lon      (lon) float64
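To see why slice(-210, 210) returns the same four points as slice(160, 210), it helps to trace where each endpoint folds to. The sketch below is a pure-Python model of the slice branch of the sel method above, using a sorted list and bisect instead of a pandas.Index; the names fold and periodic_slice_labels are hypothetical, and the folding is anchored at the minimum label.

```python
from bisect import bisect_left, bisect_right

LONS = list(range(-180, 181, 20))  # same labels as np.linspace(-180, 180, 19)
PERIOD = 360


def fold(value):
    """Fold a label into [LONS[0], LONS[0] + PERIOD)."""
    return LONS[0] + (value - LONS[0]) % PERIOD


def periodic_slice_labels(start, stop):
    """Model of the slice branch: fold both ends, then split the slice if they cross."""
    if stop < start:
        return []  # reversed slice -> empty selection, as in the thread
    lo, hi = fold(start), fold(stop)
    if lo < hi:
        # simple case: one contiguous run of labels
        positions = list(range(bisect_left(LONS, lo), bisect_right(LONS, hi)))
    else:
        # crossing case: [lo, max] followed by [min, hi]
        positions = list(range(bisect_left(LONS, lo), len(LONS))) + list(
            range(0, bisect_right(LONS, hi))
        )
    return [LONS[i] for i in positions]


for s in [slice(60, 120), slice(160, 210), slice(-210, -160), slice(-210, 210)]:
    print(s, "->", fold(s.start), fold(s.stop), periodic_slice_labels(s.start, s.stop))
```

In this model -210 folds to 150, so the selection starts at the first label at or above 150 (that is, 160), and 210 folds to -150, the same stopping point as for slice(160, 210). Whether the original slice spanned more than one full period is lost once both endpoints are folded, which is one reason the result may not feel intuitive.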

@benbovy
Member

benbovy commented Sep 16, 2022

In general, it seems like most (nearly all?) 1D indexing use cases can be handled by encapsulating a PandasIndex (see also https://github.com/dcherian/crsindex). So perhaps we should just recommend that and add a lot more comments to PandasIndex to make it easier to build on.

I've created a MultiPandasIndex helper class for that purpose: notebook.

I've extracted the boilerplate from @dcherian's CRSIndex and I've implemented the remaining Index API. It raised a couple of issues, notably for Index.isel which should probably return a dict[Hashable, Index] instead of Index | None (the latter is not flexible enough, i.e., when the dimensions of the meta-index are partially dropped in the selection).

The other thing I will think about is whether anything special needs to happen for 2D+ periodicity. I suspect that for integers you could just use independent 1D indexes along each dim but for slicing across the "dateline" it might get messy in higher dimensions...

Yeah I guess it will work well with independent PeriodicBoundaryIndex instances (possibly grouped in a MultiPandasIndex) for gridded data.
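For gridded data, independent 1D indexes along each dimension amount to folding each coordinate with its own period. A minimal pure-Python sketch, with fold_point as a hypothetical name; a period of None marks a non-periodic dimension such as latitude, which would simply not get a periodic index.

```python
def fold_point(point, lows, periods):
    """Fold an N-D label tuple dimension by dimension; a period of None means 'not periodic'."""
    folded = []
    for value, lo, period in zip(point, lows, periods):
        if period is None:
            folded.append(value)  # non-periodic dimension: leave the label alone
        else:
            folded.append(lo + (value - lo) % period)
    return tuple(folded)


# (x, y) on a doubly periodic 10 x 4 domain starting at the origin
print(fold_point((12, -1), lows=(0, 0), periods=(10, 4)))            # (2, 3)
# (lon, lat): only longitude is periodic
print(fold_point((200, 45), lows=(-180, -90), periods=(360, None)))  # (-160, 45)
```

This per-dimension folding covers point lookups; slices that cross the boundary in several dimensions at once are where it gets messier, since each dimension can split into two pieces.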

For multi-dimension coordinates with periodic boundaries this would probably be best handled by more specific indexes, e.g., xoak's s2point index that supports periodicity for lat/lon data (I think).

5 participants