Skip to content

Commit

Permalink
Make argmin/max work lazy with dask (#3244)
Browse files Browse the repository at this point in the history
* Make argmin/max work lazy with dask (#3237).

* dask: Testing number of computes on reduce methods.

* what's new updated

* Fix typo

Co-Authored-By: Stephan Hoyer <[email protected]>

* Be more explicit.

Co-Authored-By: Stephan Hoyer <[email protected]>

* More explicit raise_if_dask_computes

* nanargmin/max: only set fill_value when needed
  • Loading branch information
ulijh authored and dcherian committed Sep 6, 2019
1 parent 5c6aebc commit 0a046db
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 30 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ Bug fixes
- Fix error that arises when using open_mfdataset on a series of netcdf files
having differing values for a variable attribute of type list. (:issue:`3034`)
By `Hasan Ahmad <https://github.com/HasanAhmadQ7>`_.
- Prevent :py:meth:`~xarray.DataArray.argmax` and :py:meth:`~xarray.DataArray.argmin` from calling
dask compute (:issue:`3237`). By `Ulrich Herter <https://github.com/ulijh>`_.

.. _whats-new.0.12.3:

Expand Down
29 changes: 6 additions & 23 deletions xarray/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,38 +88,21 @@ def nanmax(a, axis=None, out=None):


def nanargmin(a, axis=None):
    """Return indices of the minimum, ignoring NaNs, without eager compute.

    Parameters
    ----------
    a : array-like (numpy or dask array)
        Input array.
    axis : int, optional
        Axis along which to find the minimum.

    Returns
    -------
    Index array of the same backend as ``a``; dask input stays lazy.
    """
    if a.dtype.kind == "O":
        # Object dtype cannot go through the fast path; only here is the
        # +inf fill value actually needed (see commit note: "only set
        # fill_value when needed").
        fill_value = dtypes.get_pos_infinity(a.dtype)
        return _nan_argminmax_object("argmin", fill_value, a, axis=axis)
    # Dispatch to the backend's own nanargmin so dask arrays are not
    # computed eagerly (GH3237).
    module = dask_array if isinstance(a, dask_array_type) else nputils
    return module.nanargmin(a, axis=axis)


def nanargmax(a, axis=None):
    """Return indices of the maximum, ignoring NaNs, without eager compute.

    Parameters
    ----------
    a : array-like (numpy or dask array)
        Input array.
    axis : int, optional
        Axis along which to find the maximum.

    Returns
    -------
    Index array of the same backend as ``a``; dask input stays lazy.
    """
    if a.dtype.kind == "O":
        # Object dtype cannot go through the fast path; only here is the
        # -inf fill value actually needed (see commit note: "only set
        # fill_value when needed").
        fill_value = dtypes.get_neg_infinity(a.dtype)
        return _nan_argminmax_object("argmax", fill_value, a, axis=axis)
    # Dispatch to the backend's own nanargmax so dask arrays are not
    # computed eagerly (GH3237).
    module = dask_array if isinstance(a, dask_array_type) else nputils
    return module.nanargmax(a, axis=axis)


def nansum(a, axis=None, dtype=None, out=None, min_count=None):
Expand Down
2 changes: 2 additions & 0 deletions xarray/core/nputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,5 @@ def f(values, axis=None, **kwargs):
# NaN-aware reductions built by _create_bottleneck_method — presumably each
# wraps the bottleneck/numpy function of the same name (factory body not
# shown in this view; confirm against nputils above).
nanprod = _create_bottleneck_method("nanprod")
nancumsum = _create_bottleneck_method("nancumsum")
nancumprod = _create_bottleneck_method("nancumprod")
# nanargmin/nanargmax are exposed here so nanops can dispatch non-dask
# arrays to nputils instead of computing dask graphs eagerly (GH3237).
nanargmin = _create_bottleneck_method("nanargmin")
nanargmax = _create_bottleneck_method("nanargmax")
54 changes: 47 additions & 7 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,49 @@
dd = pytest.importorskip("dask.dataframe")


class CountingScheduler:
    """Dask scheduler wrapper that counts how many times compute is invoked.

    Raises RuntimeError once the count exceeds ``max_computes``.
    Reference: https://stackoverflow.com/questions/53289286/
    """

    def __init__(self, max_computes=0):
        self.total_computes = 0
        self.max_computes = max_computes

    def __call__(self, dsk, keys, **kwargs):
        # Every call to this scheduler is one compute; track it first so
        # the counter is accurate even when we raise below.
        self.total_computes += 1
        if self.total_computes <= self.max_computes:
            return dask.get(dsk, keys, **kwargs)
        raise RuntimeError(
            "Too many computes. Total: %d > max: %d."
            % (self.total_computes, self.max_computes)
        )


def _set_dask_scheduler(scheduler=dask.get):
    """Install *scheduler* as the dask scheduler, across dask versions.

    dask >= 0.18.0 uses ``dask.config.set(scheduler=...)``; earlier
    releases use the deprecated ``dask.set_options(get=...)``.
    """
    if LooseVersion(dask.__version__) < LooseVersion("0.18.0"):
        return dask.set_options(get=scheduler)
    return dask.config.set(scheduler=scheduler)


def raise_if_dask_computes(max_computes=0):
    """Context manager erroring out if dask computes more than *max_computes* times."""
    return _set_dask_scheduler(CountingScheduler(max_computes))


def test_raise_if_dask_computes():
    """Sanity check: triggering a compute inside the guard must raise."""
    chunked = da.from_array(np.random.RandomState(0).randn(4, 6), chunks=(2, 2))
    with raises_regex(RuntimeError, "Too many computes"):
        with raise_if_dask_computes():
            chunked.compute()


class DaskTestCase:
def assertLazyAnd(self, expected, actual, test):

with (
dask.config.set(scheduler="single-threaded")
if LooseVersion(dask.__version__) >= LooseVersion("0.18.0")
else dask.set_options(get=dask.get)
):
with _set_dask_scheduler(dask.get):
# dask.get is the syncronous scheduler, which get's set also by
# dask.config.set(scheduler="syncronous") in current versions.
test(actual, expected)

if isinstance(actual, Dataset):
Expand Down Expand Up @@ -174,7 +209,12 @@ def test_reduce(self):
v = self.lazy_var
self.assertLazyAndAllClose(u.mean(), v.mean())
self.assertLazyAndAllClose(u.std(), v.std())
self.assertLazyAndAllClose(u.argmax(dim="x"), v.argmax(dim="x"))
with raise_if_dask_computes():
actual = v.argmax(dim="x")
self.assertLazyAndAllClose(u.argmax(dim="x"), actual)
with raise_if_dask_computes():
actual = v.argmin(dim="x")
self.assertLazyAndAllClose(u.argmin(dim="x"), actual)
self.assertLazyAndAllClose((u > 1).any(), (v > 1).any())
self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x"))
with raises_regex(NotImplementedError, "dask"):
Expand Down

0 comments on commit 0a046db

Please sign in to comment.