diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 96f0ba9a4a6..620617c127a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -73,7 +73,7 @@ New Features for xarray objects. Note that xarray objects with a dask.array backend already used deterministic hashing in previous releases; this change implements it when whole xarray objects are embedded in a dask graph, e.g. when :meth:`DataArray.map` is - invoked. (:issue:`3378`, :pull:`3446`) + invoked. (:issue:`3378`, :pull:`3446`, :pull:`3515`) By `Deepak Cherian `_ and `Guido Imperiale `_. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 5e164f420c8..a192fe08cee 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -755,7 +755,9 @@ def reset_coords( return dataset def __dask_tokenize__(self): - return (type(self), self._variable, self._coords, self._name) + from dask.base import normalize_token + + return normalize_token((type(self), self._variable, self._coords, self._name)) def __dask_graph__(self): return self._to_temp_dataset().__dask_graph__() diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index dc5a315e72a..fe8abdc4b95 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -652,7 +652,11 @@ def load(self, **kwargs) -> "Dataset": return self def __dask_tokenize__(self): - return (type(self), self._variables, self._coord_names, self._attrs) + from dask.base import normalize_token + + return normalize_token( + (type(self), self._variables, self._coord_names, self._attrs) + ) def __dask_graph__(self): graphs = {k: v.__dask_graph__() for k, v in self.variables.items()} diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 916df75b3e0..f842a4a9428 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -393,7 +393,9 @@ def compute(self, **kwargs): def __dask_tokenize__(self): # Use v.data, instead of v._data, in order to cope with the wrappers # around NetCDF and the like - return type(self), self._dims, self.data, self._attrs + from dask.base import normalize_token + + return normalize_token((type(self), self._dims, self.data, self._attrs)) def __dask_graph__(self): if isinstance(self._data, dask_array_type): @@ -1973,8 +1975,10 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): self._data = PandasIndexAdapter(self._data) def __dask_tokenize__(self): + from dask.base import normalize_token + # Don't waste time converting pd.Index to np.ndarray - return (type(self), self._dims, self._data.array, self._attrs) + return normalize_token((type(self), self._dims, self._data.array, self._attrs)) def load(self): # data is already loaded into memory for IndexVariable diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index fa8ae9991d7..43b788153bc 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1283,6 +1283,32 @@ def test_token_identical(obj, transform): ) +def test_recursive_token(): + """Test that tokenization is invoked recursively, and doesn't just rely on the + output of str() + """ + a = np.ones(10000) + b = np.ones(10000) + b[5000] = 2 + assert str(a) == str(b) + assert dask.base.tokenize(a) != dask.base.tokenize(b) + + # Test DataArray and Variable + da_a = DataArray(a) + da_b = DataArray(b) + assert dask.base.tokenize(da_a) != dask.base.tokenize(da_b) + + # Test Dataset + ds_a = da_a.to_dataset(name="x") + ds_b = da_b.to_dataset(name="x") + assert dask.base.tokenize(ds_a) != dask.base.tokenize(ds_b) + + # Test IndexVariable + da_a = DataArray(a, dims=["x"], coords={"x": a}) + da_b = DataArray(a, dims=["x"], coords={"x": b}) + assert dask.base.tokenize(da_a) != dask.base.tokenize(da_b) + + @requires_scipy_or_netCDF4 def test_normalize_token_with_backend(map_ds): with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp_file: diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py index a31da162487..a02fef2faeb 100644 --- a/xarray/tests/test_sparse.py +++ b/xarray/tests/test_sparse.py @@ -856,6 +856,10 @@ def test_dask_token(): import dask s = sparse.COO.from_numpy(np.array([0, 0, 1, 2])) + + # https://github.com/pydata/sparse/issues/300 + s.__dask_tokenize__ = lambda: dask.base.normalize_token(s.__dict__) + a = DataArray(s) t1 = dask.base.tokenize(a) t2 = dask.base.tokenize(a)