Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sel along 1D non-index coordinates #3925

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,27 +205,35 @@ def convert_label_indexer(index, label, index_name="", method=None, tolerance=No

def get_dim_indexers(data_obj, indexers):
"""Given a xarray data object and label based indexers, return a mapping
of label indexers with only dimension names as keys.
of label indexers with only dimension names (or 1D non-dimensional coord
names) as keys.

It groups multiple level indexers given on a multi-index dimension
into a single, dictionary indexer for that dimension (Raise a ValueError
if it is not possible).
"""
non_dim_1d_coords = [coord for coord in data_obj.coords
if len(data_obj[coord].dims) == 1]
invalid = [
k
for k in indexers
if k not in data_obj.dims and k not in data_obj._level_coords
k for k in indexers
if k not in data_obj.dims and k not in data_obj._level_coords
and k not in non_dim_1d_coords
]
if invalid:
raise ValueError(f"dimensions or multi-index levels {invalid!r} do not exist")
if invalid: # TODO This was never covered by testing
raise ValueError(f"dimensions, 1D coordinates, or multi-index levels"
f" {invalid!r} do not exist")

level_indexers = defaultdict(dict)
dim_indexers = {}
for key, label in indexers.items():
(dim,) = data_obj[key].dims
if key != dim:
# assume here multi-index level indexer
level_indexers[dim][key] = label
# If key is 1D non-dimension coordinate let it pass through
if key in data_obj.coords and data_obj.coords[key].dims == (dim,):
dim_indexers[key] = label
else:
# assume here multi-index level indexer
level_indexers[dim][key] = label
else:
dim_indexers[key] = label

Expand Down Expand Up @@ -253,9 +261,31 @@ def remap_label_indexers(data_obj, indexers, method=None, tolerance=None):

dim_indexers = get_dim_indexers(data_obj, indexers)
for dim, label in dim_indexers.items():
try:
if dim in data_obj.indexes:
index = data_obj.indexes[dim]
except KeyError:

coords_dtype = data_obj.coords[dim].dtype
label = maybe_cast_to_coords_dtype(label, coords_dtype)
idxr, new_idx = convert_label_indexer(index, label, dim,
method, tolerance)
pos_indexers[dim] = idxr
if new_idx is not None:
new_indexes[dim] = new_idx

elif dim in data_obj.coords and len(data_obj.coords[dim] == 1):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be data_obj.coords[dim].ndim == 1 instead of len(data_obj.coords[dim] == 1)? The latter is checking that the coordinate's first dimension has length 1.

# 1D non-dimension coord
index = data_obj.coords[dim].to_index()
(dim,) = data_obj[dim].dims

coords_dtype = data_obj.coords[dim].dtype
label = maybe_cast_to_coords_dtype(label, coords_dtype)
idxr, new_idx = convert_label_indexer(index, label, dim,
method, tolerance)
pos_indexers[dim] = idxr
if new_idx is not None:
new_indexes[dim] = new_idx
Comment on lines +280 to +286
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you refactor so this code block doesn't need to be repeated?

Every index is a 1D coordinate, so I think you could do this by switching around the order of these if checks, e.g.,

    if dim in data_obj.coords and data_obj.coords[dim].ndim == 1:
        if dim in data_obj.indexes:
            index = data_obj.indexes[dim]
        else:
            index = data_obj.coords[dim].to_index()
            (dim,) = data_obj[dim].dims

        coords_dtype = data_obj.coords[dim].dtype
        label = maybe_cast_to_coords_dtype(label, coords_dtype)
        idxr, new_idx = convert_label_indexer(index, label, dim,
                                              method, tolerance)
        pos_indexers[dim] = idxr
        if new_idx is not None:
            new_indexes[dim] = new_idx


else:
# no index for this dimension: reuse the provided labels
if method is not None or tolerance is not None:
raise ValueError(
Expand All @@ -264,13 +294,6 @@ def remap_label_indexers(data_obj, indexers, method=None, tolerance=None):
"an associated coordinate."
)
pos_indexers[dim] = label
else:
coords_dtype = data_obj.coords[dim].dtype
label = maybe_cast_to_coords_dtype(label, coords_dtype)
idxr, new_idx = convert_label_indexer(index, label, dim, method, tolerance)
pos_indexers[dim] = idxr
if new_idx is not None:
new_indexes[dim] = new_idx

return pos_indexers, new_indexes

Expand Down
23 changes: 19 additions & 4 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
pass


def create_test_data(seed=None):
def create_test_data(seed=None, dim_coords=True):
rs = np.random.RandomState(seed)
_vars = {
"var1": ["dim1", "dim2"],
Expand All @@ -66,9 +66,14 @@ def create_test_data(seed=None):
_dims = {"dim1": 8, "dim2": 9, "dim3": 10}

obj = Dataset()
obj["time"] = ("time", pd.date_range("2000-01-01", periods=20))
obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"]))
obj["dim3"] = ("dim3", list("abcdefghij"))
if dim_coords:
obj["time"] = ("time", pd.date_range("2000-01-01", periods=20))
obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"]))
obj["dim3"] = ("dim3", list("abcdefghij"))
else:
obj.coords["time_coord"] = ("time", pd.date_range("2000-01-01", periods=20))
obj.coords["coord2"] = ("dim2", 0.5 * np.arange(_dims["dim2"]))
obj.coords["coord3"] = ("dim3", list("abcdefghij"))
for v, dims in sorted(_vars.items()):
data = rs.normal(size=tuple(_dims[d] for d in dims))
obj[v] = (dims, data, {"foo": "variable"})
Expand Down Expand Up @@ -1368,6 +1373,16 @@ def test_sel_dataarray(self):
assert_equal(actual.drop_vars("new_dim"), expected)
assert np.allclose(actual["new_dim"].values, ind["new_dim"].values)

def test_sel_non_index_coord(self):
data = create_test_data(dim_coords=False)
int_slicers = {"dim2": slice(2),
"dim3": slice(3)}
loc_slicers = {
"coord2": slice(0, 0.5),
"coord3": slice("a", "c"),
}
assert_equal(data.isel(**int_slicers), data.sel(**loc_slicers))

def test_sel_dataarray_mindex(self):
midx = pd.MultiIndex.from_product([list("abc"), [0, 1]], names=("one", "two"))
mds = xr.Dataset(
Expand Down