diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 9d3e64badb8..5f1eef09847 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -60,6 +60,8 @@ Internal Changes
 - Use Python 3.6 idioms throughout the codebase. (:pull:`3419`)
   By `Maximilian Roos <https://github.com/max-sixty>`_
+- Implement :py:func:`__dask_tokenize__` for xarray objects.
+  By `Deepak Cherian <https://github.com/dcherian>`_
 
 .. _whats-new.0.14.0:
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 5fccb9236e8..b9cff19009d 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -754,6 +754,12 @@ def reset_coords(
         dataset[self.name] = self.variable
         return dataset
 
+    def __dask_tokenize__(self):
+        # Import locally so that dask remains an optional dependency.
+        from dask.base import normalize_token
+
+        return normalize_token((DataArray, self._variable, self._coords, self._name))
+
     def __dask_graph__(self):
         return self._to_temp_dataset().__dask_graph__()
 
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 12d5cbdc9f3..3f2a9b2bb3c 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -648,6 +648,13 @@ def load(self, **kwargs) -> "Dataset":
 
         return self
 
+    def __dask_tokenize__(self):
+        from dask.base import normalize_token
+
+        return normalize_token(
+            (Dataset, self._variables, self._coord_names, self._attrs)
+        )
+
     def __dask_graph__(self):
         graphs = {k: v.__dask_graph__() for k, v in self.variables.items()}
         graphs = {k: v for k, v in graphs.items() if v is not None}
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 37672cd82d9..bc6d3695ae3 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -389,6 +389,13 @@ def compute(self, **kwargs):
         new = self.copy(deep=False)
         return new.load(**kwargs)
 
+    def __dask_tokenize__(self):
+        # Use self.data rather than self._data, to cope with the wrappers
+        # around NetCDF and the like.
+        from dask.base import normalize_token
+
+        return normalize_token((Variable, self._dims, self.data, self._attrs))
+
     def __dask_graph__(self):
         if isinstance(self._data, dask_array_type):
             return self._data.__dask_graph__()
@@ -1961,6 +1968,14 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
         if not isinstance(self._data, PandasIndexAdapter):
             self._data = PandasIndexAdapter(self._data)
 
+    def __dask_tokenize__(self):
+        from dask.base import normalize_token
+
+        # Don't waste time converting pd.Index to np.ndarray.
+        return normalize_token(
+            (IndexVariable, self._dims, self._data.array, self._attrs)
+        )
+
     def load(self):
         # data is already loaded into memory for IndexVariable
         return self
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index ae8f43cb66d..091a63f471d 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -22,6 +22,7 @@
     assert_identical,
     raises_regex,
 )
+from .test_backends import create_tmp_file
 
 dask = pytest.importorskip("dask")
 da = pytest.importorskip("dask.array")
@@ -1135,3 +1136,57 @@ def test_make_meta(map_ds):
     for variable in map_ds.data_vars:
         assert variable in meta.data_vars
         assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim
+
+
+@pytest.mark.parametrize(
+    "obj", [make_da(), make_da().compute(), make_ds(), make_ds().compute()]
+)
+@pytest.mark.parametrize(
+    "transform",
+    [
+        lambda x: x.reset_coords(),
+        lambda x: x.reset_coords(drop=True),
+        lambda x: x.isel(x=1),
+        lambda x: x.assign_attrs(new_attrs=1),
+        lambda x: x.assign_coords(cxy=1),
+        lambda x: x.rename({"x": "xnew"}),
+        lambda x: x.rename({"cxy": "cxynew"}),
+    ],
+)
+def test_normalize_token_not_identical(obj, transform):
+    with raise_if_dask_computes():
+        assert dask.base.tokenize(obj) != dask.base.tokenize(transform(obj))
+
+
+@pytest.mark.parametrize("transform", [lambda x: x, lambda x: x.compute()])
+def test_normalize_differently_when_data_changes(transform):
+    obj = transform(make_ds())
+    new = obj.copy(deep=True)
+    new["a"] *= 2
+    with raise_if_dask_computes():
+        assert dask.base.tokenize(obj) != dask.base.tokenize(new)
+
+    obj = transform(make_da())
+    new = obj.copy(deep=True)
+    new *= 2
+    with raise_if_dask_computes():
+        assert dask.base.tokenize(obj) != dask.base.tokenize(new)
+
+
+@pytest.mark.parametrize(
+    "transform", [lambda x: x, lambda x: x.copy(), lambda x: x.copy(deep=True)]
+)
+@pytest.mark.parametrize(
+    "obj", [make_da(), make_ds(), make_da().indexes["x"], make_ds().variables["a"]]
+)
+def test_normalize_token_identical(obj, transform):
+    with raise_if_dask_computes():
+        assert dask.base.tokenize(obj) == dask.base.tokenize(transform(obj))
+
+
+def test_normalize_token_netcdf_backend(map_ds):
+    with create_tmp_file() as tmp_file:
+        map_ds.to_netcdf(tmp_file)
+        read = xr.open_dataset(tmp_file)
+        assert dask.base.tokenize(map_ds) != dask.base.tokenize(read)
+        read.close()
diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py
index bd26b96f6d4..1cf915151ab 100644
--- a/xarray/tests/test_sparse.py
+++ b/xarray/tests/test_sparse.py
@@ -22,6 +22,7 @@
 )
 
 sparse = pytest.importorskip("sparse")
+dask = pytest.importorskip("dask")
 
 
 def assert_sparse_equal(a, b):
@@ -849,3 +850,14 @@ def test_chunk():
     dsc = ds.chunk(2)
     assert dsc.chunks == {"dim_0": (2, 2)}
     assert_identical(dsc, ds)
+
+
+def test_normalize_token():
+    s = sparse.COO.from_numpy(np.array([0, 0, 1, 2]))
+    a = DataArray(s)
+    dask.base.tokenize(a)
+    assert isinstance(a.data, sparse.COO)
+
+    ac = a.chunk(2)
+    dask.base.tokenize(ac)
+    assert isinstance(ac.data._meta, sparse.COO)
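
As a usage sketch (not part of the patch; the DataArray below is invented for illustration), the new hooks let dask fingerprint xarray objects deterministically, without computing any lazily held data. This mirrors what test_normalize_token_identical and test_normalize_token_not_identical assert:

import dask.base
import numpy as np
import xarray as xr

# A small chunked DataArray; tokenizing it must not trigger computation.
da = xr.DataArray(
    np.arange(6).reshape(2, 3), dims=("x", "y"), name="a"
).chunk({"x": 1})

# Structurally identical objects hash to the same deterministic token ...
assert dask.base.tokenize(da) == dask.base.tokenize(da.copy(deep=True))

# ... while any change to name, dims, coords, attrs, or data changes it.
assert dask.base.tokenize(da) != dask.base.tokenize(da.rename("b"))

The method-local `from dask.base import normalize_token` imports keep dask an optional dependency: the hooks are only invoked by dask itself, so the import cannot fail when they run.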