From af28c6b02fac08494f5d9ae2718d68a084d93949 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 5 Nov 2019 15:41:13 +0000 Subject: [PATCH] Optimize dask array equality checks. (#3453) * Optimize dask array equality checks. Dask arrays with the same graph have the same name. We can use this to quickly compare dask-backed variables without computing. Fixes #3068 and #3311 * better docstring * review suggestions. * add concat test * update whats new * Add identity check to lazy_array_equiv * pep8 * bugfix. --- doc/whats-new.rst | 3 + xarray/core/concat.py | 56 ++++++++++++------ xarray/core/duck_array_ops.py | 62 ++++++++++++++----- xarray/core/merge.py | 19 ++++-- xarray/core/variable.py | 14 +++-- xarray/tests/test_dask.py | 108 +++++++++++++++++++++++++++++++++- 6 files changed, 217 insertions(+), 45 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c117382f23f..dcaab011e67 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -70,6 +70,9 @@ Bug fixes but cloudpickle isn't (:issue:`3401`) by `Rhys Doyle `_ - Fix grouping over variables with NaNs. (:issue:`2383`, :pull:`3406`). By `Deepak Cherian `_. +- Use dask names to compare dask objects prior to comparing values after computation. + (:issue:`3068`, :issue:`3311`, :issue:`3454`, :pull:`3453`). + By `Deepak Cherian `_. - Sync with cftime by removing `dayofwk=-1` for cftime>=1.0.4. By `Anderson Banihirwe `_. - Fix :py:meth:`xarray.core.groupby.DataArrayGroupBy.reduce` and diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 0d19990bdd0..c26153eb0d8 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -2,6 +2,7 @@ from . import dtypes, utils from .alignment import align +from .duck_array_ops import lazy_array_equiv from .merge import _VALID_COMPAT, unique_variable from .variable import IndexVariable, Variable, as_variable from .variable import concat as concat_vars @@ -189,26 +190,43 @@ def process_subset_opt(opt, subset): # all nonindexes that are not the same in each dataset for k in getattr(datasets[0], subset): if k not in concat_over: - # Compare the variable of all datasets vs. the one - # of the first dataset. Perform the minimum amount of - # loads in order to avoid multiple loads from disk - # while keeping the RAM footprint low. - v_lhs = datasets[0].variables[k].load() - # We'll need to know later on if variables are equal. - computed = [] - for ds_rhs in datasets[1:]: - v_rhs = ds_rhs.variables[k].compute() - computed.append(v_rhs) - if not getattr(v_lhs, compat)(v_rhs): - concat_over.add(k) - equals[k] = False - # computed variables are not to be re-computed - # again in the future - for ds, v in zip(datasets[1:], computed): - ds.variables[k].data = v.data + equals[k] = None + variables = [ds.variables[k] for ds in datasets] + # first check without comparing values i.e. no computes + for var in variables[1:]: + equals[k] = getattr(variables[0], compat)( + var, equiv=lazy_array_equiv + ) + if equals[k] is not True: + # exit early if we know these are not equal or that + # equality cannot be determined i.e. one or all of + # the variables wraps a numpy array break - else: - equals[k] = True + + if equals[k] is False: + concat_over.add(k) + + elif equals[k] is None: + # Compare the variable of all datasets vs. the one + # of the first dataset. Perform the minimum amount of + # loads in order to avoid multiple loads from disk + # while keeping the RAM footprint low. + v_lhs = datasets[0].variables[k].load() + # We'll need to know later on if variables are equal. + computed = [] + for ds_rhs in datasets[1:]: + v_rhs = ds_rhs.variables[k].compute() + computed.append(v_rhs) + if not getattr(v_lhs, compat)(v_rhs): + concat_over.add(k) + equals[k] = False + # computed variables are not to be re-computed + # again in the future + for ds, v in zip(datasets[1:], computed): + ds.variables[k].data = v.data + break + else: + equals[k] = True elif opt == "all": concat_over.update( diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index d943788c434..71e79335c3d 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -174,14 +174,42 @@ def as_shared_dtype(scalars_or_arrays): return [x.astype(out_type, copy=False) for x in arrays] -def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): - """Like np.allclose, but also allows values to be NaN in both arrays +def lazy_array_equiv(arr1, arr2): + """Like array_equal, but doesn't actually compare values. + Returns True when arr1, arr2 identical or their dask names are equal. + Returns False when shapes are not equal. + Returns None when equality cannot determined: one or both of arr1, arr2 are numpy arrays; + or their dask names are not equal """ + if arr1 is arr2: + return True arr1 = asarray(arr1) arr2 = asarray(arr2) if arr1.shape != arr2.shape: return False - return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) + if ( + dask_array + and isinstance(arr1, dask_array.Array) + and isinstance(arr2, dask_array.Array) + ): + # GH3068 + if arr1.name == arr2.name: + return True + else: + return None + return None + + +def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): + """Like np.allclose, but also allows values to be NaN in both arrays + """ + arr1 = asarray(arr1) + arr2 = asarray(arr2) + lazy_equiv = lazy_array_equiv(arr1, arr2) + if lazy_equiv is None: + return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) + else: + return lazy_equiv def array_equiv(arr1, arr2): @@ -189,12 +217,14 @@ def array_equiv(arr1, arr2): """ arr1 = asarray(arr1) arr2 = asarray(arr2) - if arr1.shape != arr2.shape: - return False - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "In the future, 'NAT == x'") - flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2)) - return bool(flag_array.all()) + lazy_equiv = lazy_array_equiv(arr1, arr2) + if lazy_equiv is None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "In the future, 'NAT == x'") + flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2)) + return bool(flag_array.all()) + else: + return lazy_equiv def array_notnull_equiv(arr1, arr2): @@ -203,12 +233,14 @@ def array_notnull_equiv(arr1, arr2): """ arr1 = asarray(arr1) arr2 = asarray(arr2) - if arr1.shape != arr2.shape: - return False - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "In the future, 'NAT == x'") - flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2) - return bool(flag_array.all()) + lazy_equiv = lazy_array_equiv(arr1, arr2) + if lazy_equiv is None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "In the future, 'NAT == x'") + flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2) + return bool(flag_array.all()) + else: + return lazy_equiv def count(data, axis=None): diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 389ceb155f7..daf0c3b059f 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -19,6 +19,7 @@ from . import dtypes, pdcompat from .alignment import deep_align +from .duck_array_ops import lazy_array_equiv from .utils import Frozen, dict_equiv from .variable import Variable, as_variable, assert_unique_multiindex_level_names @@ -123,16 +124,24 @@ def unique_variable( combine_method = "fillna" if equals is None: - out = out.compute() + # first check without comparing values i.e. no computes for var in variables[1:]: - equals = getattr(out, compat)(var) - if not equals: + equals = getattr(out, compat)(var, equiv=lazy_array_equiv) + if equals is not True: break + if equals is None: + # now compare values with minimum number of computes + out = out.compute() + for var in variables[1:]: + equals = getattr(out, compat)(var) + if not equals: + break + if not equals: raise MergeError( - "conflicting values for variable {!r} on objects to be combined. " - "You can skip this check by specifying compat='override'.".format(name) + f"conflicting values for variable {name!r} on objects to be combined. " + "You can skip this check by specifying compat='override'." ) if combine_method: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 117ab85ae65..916df75b3e0 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1236,7 +1236,9 @@ def transpose(self, *dims) -> "Variable": dims = self.dims[::-1] dims = tuple(infix_dims(dims, self.dims)) axes = self.get_axis_num(dims) - if len(dims) < 2: # no need to transpose if only one dimension + if len(dims) < 2 or dims == self.dims: + # no need to transpose if only one dimension + # or dims are in same order return self.copy(deep=False) data = as_indexable(self._data).transpose(axes) @@ -1595,22 +1597,24 @@ def broadcast_equals(self, other, equiv=duck_array_ops.array_equiv): return False return self.equals(other, equiv=equiv) - def identical(self, other): + def identical(self, other, equiv=duck_array_ops.array_equiv): """Like equals, but also checks attributes. """ try: - return utils.dict_equiv(self.attrs, other.attrs) and self.equals(other) + return utils.dict_equiv(self.attrs, other.attrs) and self.equals( + other, equiv=equiv + ) except (TypeError, AttributeError): return False - def no_conflicts(self, other): + def no_conflicts(self, other, equiv=duck_array_ops.array_notnull_equiv): """True if the intersection of two Variable's non-null data is equal; otherwise false. Variables can thus still be equal if there are locations where either, or both, contain NaN values. """ - return self.broadcast_equals(other, equiv=duck_array_ops.array_notnull_equiv) + return self.broadcast_equals(other, equiv=equiv) def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None): """Compute the qth quantile of the data along the specified dimension. diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index c4323d1d317..34115b29b23 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -24,6 +24,7 @@ raises_regex, requires_scipy_or_netCDF4, ) +from ..core.duck_array_ops import lazy_array_equiv from .test_backends import create_tmp_file dask = pytest.importorskip("dask") @@ -428,7 +429,53 @@ def test_concat_loads_variables(self): out.compute() assert kernel_call_count == 24 - # Finally, test that riginals are unaltered + # Finally, test that originals are unaltered + assert ds1["d"].data is d1 + assert ds1["c"].data is c1 + assert ds2["d"].data is d2 + assert ds2["c"].data is c2 + assert ds3["d"].data is d3 + assert ds3["c"].data is c3 + + # now check that concat() is correctly using dask name equality to skip loads + out = xr.concat( + [ds1, ds1, ds1], dim="n", data_vars="different", coords="different" + ) + assert kernel_call_count == 24 + # variables are not loaded in the output + assert isinstance(out["d"].data, dask.array.Array) + assert isinstance(out["c"].data, dask.array.Array) + + out = xr.concat( + [ds1, ds1, ds1], dim="n", data_vars=[], coords=[], compat="identical" + ) + assert kernel_call_count == 24 + # variables are not loaded in the output + assert isinstance(out["d"].data, dask.array.Array) + assert isinstance(out["c"].data, dask.array.Array) + + out = xr.concat( + [ds1, ds2.compute(), ds3], + dim="n", + data_vars="all", + coords="different", + compat="identical", + ) + # c1,c3 must be computed for comparison since c2 is numpy; + # d2 is computed too + assert kernel_call_count == 28 + + out = xr.concat( + [ds1, ds2.compute(), ds3], + dim="n", + data_vars="all", + coords="all", + compat="identical", + ) + # no extra computes + assert kernel_call_count == 30 + + # Finally, test that originals are unaltered assert ds1["d"].data is d1 assert ds1["c"].data is c1 assert ds2["d"].data is d2 @@ -1142,6 +1189,19 @@ def test_make_meta(map_ds): assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim +def test_identical_coords_no_computes(): + lons2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x")) + a = xr.DataArray( + da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2} + ) + b = xr.DataArray( + da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2} + ) + with raise_if_dask_computes(): + c = a + b + assert_identical(c, a) + + @pytest.mark.parametrize( "obj", [make_da(), make_da().compute(), make_ds(), make_ds().compute()] ) @@ -1229,3 +1289,49 @@ def test_normalize_token_with_backend(map_ds): map_ds.to_netcdf(tmp_file) read = xr.open_dataset(tmp_file) assert not dask.base.tokenize(map_ds) == dask.base.tokenize(read) + + +@pytest.mark.parametrize( + "compat", ["broadcast_equals", "equals", "identical", "no_conflicts"] +) +def test_lazy_array_equiv_variables(compat): + var1 = xr.Variable(("y", "x"), da.zeros((10, 10), chunks=2)) + var2 = xr.Variable(("y", "x"), da.zeros((10, 10), chunks=2)) + var3 = xr.Variable(("y", "x"), da.zeros((20, 10), chunks=2)) + + with raise_if_dask_computes(): + assert getattr(var1, compat)(var2, equiv=lazy_array_equiv) + # values are actually equal, but we don't know that till we compute, return None + with raise_if_dask_computes(): + assert getattr(var1, compat)(var2 / 2, equiv=lazy_array_equiv) is None + + # shapes are not equal, return False without computes + with raise_if_dask_computes(): + assert getattr(var1, compat)(var3, equiv=lazy_array_equiv) is False + + # if one or both arrays are numpy, return None + assert getattr(var1, compat)(var2.compute(), equiv=lazy_array_equiv) is None + assert ( + getattr(var1.compute(), compat)(var2.compute(), equiv=lazy_array_equiv) is None + ) + + with raise_if_dask_computes(): + assert getattr(var1, compat)(var2.transpose("y", "x")) + + +@pytest.mark.parametrize( + "compat", ["broadcast_equals", "equals", "identical", "no_conflicts"] +) +def test_lazy_array_equiv_merge(compat): + da1 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x")) + da2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x")) + da3 = xr.DataArray(da.ones((20, 10), chunks=2), dims=("y", "x")) + + with raise_if_dask_computes(): + xr.merge([da1, da2], compat=compat) + # shapes are not equal; no computes necessary + with raise_if_dask_computes(max_computes=0): + with pytest.raises(ValueError): + xr.merge([da1, da3], compat=compat) + with raise_if_dask_computes(max_computes=2): + xr.merge([da1, da2 / 2], compat=compat)