Skip to content

Commit

Permalink
Let vindex accept a Dask Array indexer under certain conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl committed Jan 3, 2025
1 parent c529e43 commit af36299
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
21 changes: 17 additions & 4 deletions dask/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2037,10 +2037,19 @@ def _vindex(self, key):
"Use normal slicing instead when only using slices. Got: {}".format(key)
)
elif any(is_dask_collection(k) for k in key):
raise IndexError(
"vindex does not support indexing with dask objects. Call compute "
"on the indexer first to get an evalurated array. Got: {}".format(key)
)
if math.prod(self.numblocks) == 1 and len(key) == 1 and self.ndim == 1:
idxr = key[0]
# we can broadcast in this case
return idxr.map_blocks(
_numpy_vindex, self, dtype=self.dtype, chunks=idxr.chunks
)
else:
raise IndexError(
"vindex does not support indexing with dask objects. Call compute "
"on the indexer first to get an evalurated array. Got: {}".format(
key
)
)
return _vindex(self, *key)

@property
Expand Down Expand Up @@ -6075,4 +6084,8 @@ def ravel(self) -> list[Array]:
return [self[idx] for idx in np.ndindex(self.shape)]


def _numpy_vindex(indexer, arr):
return arr[indexer]


from dask.array.blockwise import blockwise
9 changes: 9 additions & 0 deletions dask/array/tests/test_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,3 +1111,12 @@ def test_minimal_dtype_doesnt_overflow():
ib = np.zeros(1980, dtype=bool)
ib[1560:1860] = True
assert_eq(dx[ib], x[ib])


def test_vindex_with_dask_array():
arr = np.array([0.2, 0.4, 0.6])
darr = da.from_array(arr, chunks=-1)

indexer = np.random.randint(0, 3, 8).reshape(4, 2).astype(int)
dindexer = da.from_array(indexer, chunks=(2, 2))
assert_eq(darr.vindex[dindexer], arr[indexer])

0 comments on commit af36299

Please sign in to comment.