diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 70994a36ac8..f5cad4db13a 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -18,7 +18,7 @@ is_duck_dask_array, sparse_array_type, ) -from .utils import maybe_cast_to_coords_dtype +from .utils import is_duck_array, maybe_cast_to_coords_dtype def expanded_indexer(key, ndim): @@ -307,7 +307,7 @@ def __init__(self, key): for k in key: if isinstance(k, slice): k = as_integer_slice(k) - elif isinstance(k, np.ndarray): + elif is_duck_array(k): if not np.issubdtype(k.dtype, np.integer): raise TypeError( f"invalid indexer array, does not have integer dtype: {k!r}" @@ -320,7 +320,7 @@ def __init__(self, key): "invalid indexer key: ndarray arguments " f"have different numbers of dimensions: {ndims}" ) - k = np.asarray(k, dtype=np.int64) + k = k.astype(np.int64) else: raise TypeError( f"unexpected indexer type for {type(self).__name__}: {k!r}" diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 533f4a0cd62..a0263f1a2c3 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -9,6 +9,8 @@ from . import IndexerMaker, ReturnItem, assert_array_equal +da = pytest.importorskip("dask.array") + B = IndexerMaker(indexing.BasicIndexer) @@ -729,3 +731,16 @@ def test_indexing_1d_object_array() -> None: expected = DataArray(expected_data) assert [actual.data.item()] == [expected.data.item()] + + +def test_indexing_dask_array(): + da = DataArray( + np.ones(10 * 3 * 3).reshape((10, 3, 3)), + dims=("time", "x", "y"), + ).chunk(dict(time=-1, x=1, y=1)) + da[{"time": 9}] = 42 + + idx = da.argmax("time") + actual = da.isel(time=idx) + + assert np.all(actual == 42)