From a4ba2bcdfff403557896b8e20feb682436d52c79 Mon Sep 17 00:00:00 2001
From: Deepak Cherian <deepak@cherian.net>
Date: Wed, 20 Nov 2024 20:39:01 -0700
Subject: [PATCH] Migrate to DaskIndexingAdapter

---
 xarray/core/computation.py | 38 +++++------------------------
 xarray/core/indexing.py    | 49 ++++++++++++++++++++++++++++++++------
 2 files changed, 48 insertions(+), 39 deletions(-)

diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index 6becb21ffca..649eeb60282 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -2126,18 +2126,6 @@ def to_floatable(x: DataArray) -> DataArray:
         return to_floatable(data)
 
 
-def _apply_vectorized_indexer(indices, coord):
-    from xarray.core.indexing import (
-        VectorizedIndexer,
-        apply_indexer,
-        as_indexable,
-    )
-
-    return apply_indexer(
-        as_indexable(coord), VectorizedIndexer((indices.squeeze(axis=-1),))
-    )
-
-
 def _calc_idxminmax(
     *,
     array,
@@ -2182,28 +2170,14 @@ def _calc_idxminmax(
     indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)
 
     # Handle chunked arrays (e.g. dask).
+    coord = array[dim]._variable.to_base_variable()
     if is_chunked_array(array.data):
         chunkmanager = get_chunked_array_type(array.data)
-        chunked_coord = chunkmanager.from_array(array[dim].data, chunks=((-1,),))
-
-        if indx.ndim == 0:
-            out = chunked_coord[indx.data]
-        else:
-            out = chunkmanager.map_blocks(
-                _apply_vectorized_indexer,
-                indx.data[..., np.newaxis],
-                chunked_coord,
-                chunks=indx.data.chunks,
-                drop_axis=-1,
-                dtype=chunked_coord.dtype,
-            )
-        res = indx.copy(data=out)
-        # we need to attach back the dim name
-        res.name = dim
-    else:
-        res = array[dim][(indx,)]
-        # The dim is gone but we need to remove the corresponding coordinate.
-        del res.coords[dim]
+        coord_array = chunkmanager.from_array(
+            array[dim].data, chunks=((array.sizes[dim],),)
+        )
+        coord = coord.copy(data=coord_array)
+    res = indx._replace(coord[(indx.variable,)]).rename(dim)
 
     if skipna or (skipna is None and array.dtype.kind in na_dtypes):
         # Put the NaN values back in after removing them
diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index 0a7b94a53c7..fde90ee71d0 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -2,6 +2,7 @@
 
 import enum
 import functools
+import math
 import operator
 from collections import Counter, defaultdict
 from collections.abc import Callable, Hashable, Iterable, Mapping
@@ -472,12 +473,12 @@ def __init__(self, key: tuple[slice | np.ndarray[Any, np.dtype[np.generic]], ...
         for k in key:
             if isinstance(k, slice):
                 k = as_integer_slice(k)
-            elif is_duck_dask_array(k):
-                raise ValueError(
-                    "Vectorized indexing with Dask arrays is not supported. "
-                    "Please pass a numpy array by calling ``.compute``. "
-                    "See https://github.com/dask/dask/issues/8958."
-                )
+            # elif is_duck_dask_array(k):
+            #     raise ValueError(
+            #         "Vectorized indexing with Dask arrays is not supported. "
+            #         "Please pass a numpy array by calling ``.compute``. "
+            #         "See https://github.com/dask/dask/issues/8958."
+            #     )
             elif is_duck_array(k):
                 if not np.issubdtype(k.dtype, np.integer):
                     raise TypeError(
@@ -1607,6 +1608,18 @@ def transpose(self, order):
         return xp.permute_dims(self.array, order)
 
 
+def _apply_vectorized_indexer_dask_wrapper(indices, coord):
+    from xarray.core.indexing import (
+        VectorizedIndexer,
+        apply_indexer,
+        as_indexable,
+    )
+
+    return apply_indexer(
+        as_indexable(coord), VectorizedIndexer((indices.squeeze(axis=-1),))
+    )
+
+
 class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
     """Wrap a dask array to support explicit indexing."""
 
@@ -1630,7 +1643,29 @@ def _oindex_get(self, indexer: OuterIndexer):
             return value
 
     def _vindex_get(self, indexer: VectorizedIndexer):
-        return self.array.vindex[indexer.tuple]
+        try:
+            return self.array.vindex[indexer.tuple]
+        except IndexError as e:
+            # TODO: upstream to dask
+            has_dask = any(is_duck_dask_array(i) for i in indexer.tuple)
+            if not has_dask or (has_dask and len(indexer.tuple) > 1):
+                raise e
+            if math.prod(self.array.numblocks) > 1 or self.array.ndim > 1:
+                raise e
+            (idxr,) = indexer.tuple
+            if idxr.ndim == 0:
+                return self.array[idxr.data]
+            else:
+                import dask.array
+
+                return dask.array.map_blocks(
+                    _apply_vectorized_indexer_dask_wrapper,
+                    idxr[..., np.newaxis],
+                    self.array,
+                    chunks=idxr.chunks,
+                    drop_axis=-1,
+                    dtype=self.array.dtype,
+                )
 
     def __getitem__(self, indexer: ExplicitIndexer):
         self._check_and_raise_if_non_basic_indexer(indexer)