From 678bf7f9a355fa1e69f9a9ea832a9a995d5e0c50 Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Thu, 16 Jun 2022 11:27:01 -0400 Subject: [PATCH] Replace .ds with immutable DatasetView https://github.com/xarray-contrib/datatree/pull/99 * sketching out changes needed to integrate variables into DataTree * fixed some other basic conflicts * fix mypy errors * can create basic datatree node objects again * child-variable name collisions dectected correctly * in-progres * add _replace method * updated tests to assert identical instead of check .ds is expected_ds * refactor .ds setter to use _replace * refactor init to use _replace * refactor test tree to avoid init * attempt at copy methods * rewrote implementation of .copy method * xfailing test for deepcopying * pseudocode implementation of DatasetView * Revert "pseudocode implementation of DatasetView" This reverts commit 52ef23baaa4b6892cad2d69c61b43db831044630. * pseudocode implementation of DatasetView * removed duplicated implementation of copy * reorganise API docs * expose data_vars, coords etc. properties * try except with calculate_dimensions private import * add keys/values/items methods * don't use has_data when .variables would do * change asserts to not fail just because of differing types * full sketch of how DatasetView could work * added tests for DatasetView * remove commented pseudocode * explanation of basic properties * add data structures page to index * revert adding documentation in favour of that going in a different PR * correct deepcopy tests * use .data_vars in copy tests * add test for arithmetic with .ds * remove reference to wrapping node in DatasetView * clarify type through renaming variables * remove test for out-of-node access * make imports depend on most recent version of xarray Co-authored-by: Mattia Almansi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove try except for internal import * depend on latest pre-release of xarray * correct name of version * xarray pre-release under pip in ci envs * correct methods * whatsnews * fix fixture in test * whatsnew * improve docstrings * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: Mattia Almansi Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- datatree/datatree.py | 141 +++++++++++++++++++++++++++++++- datatree/mapping.py | 6 +- datatree/tests/test_datatree.py | 83 +++++++++++++------ docs/source/whats-new.rst | 8 ++ 4 files changed, 204 insertions(+), 34 deletions(-) diff --git a/datatree/datatree.py b/datatree/datatree.py index e5a4bd4..116e723 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -19,6 +19,7 @@ Set, Tuple, Union, + overload, ) import pandas as pd @@ -86,6 +87,135 @@ def _check_for_name_collisions( ) +class DatasetView(Dataset): + """ + An immutable Dataset-like view onto the data in a single DataTree node. + + In-place operations modifying this object should raise an AttributeError. + + Operations returning a new result will return a new xarray.Dataset object. + This includes all API on Dataset, which will be inherited. + + This requires overriding all inherited private constructors. + """ + + # TODO what happens if user alters (in-place) a DataArray they extracted from this object? + + def __init__( + self, + data_vars: Mapping[Any, Any] = None, + coords: Mapping[Any, Any] = None, + attrs: Mapping[Any, Any] = None, + ): + raise AttributeError("DatasetView objects are not to be initialized directly") + + @classmethod + def _from_node( + cls, + wrapping_node: DataTree, + ) -> DatasetView: + """Constructor, using dataset attributes from wrapping node""" + + obj: DatasetView = object.__new__(cls) + obj._variables = wrapping_node._variables + obj._coord_names = wrapping_node._coord_names + obj._dims = wrapping_node._dims + obj._indexes = wrapping_node._indexes + obj._attrs = wrapping_node._attrs + obj._close = wrapping_node._close + obj._encoding = wrapping_node._encoding + + return obj + + def __setitem__(self, key, val) -> None: + raise AttributeError( + "Mutation of the DatasetView is not allowed, please use __setitem__ on the wrapping DataTree node, " + "or use `DataTree.to_dataset()` if you want a mutable dataset" + ) + + def update(self, other) -> None: + raise AttributeError( + "Mutation of the DatasetView is not allowed, please use .update on the wrapping DataTree node, " + "or use `DataTree.to_dataset()` if you want a mutable dataset" + ) + + # FIXME https://github.com/python/mypy/issues/7328 + @overload + def __getitem__(self, key: Mapping) -> Dataset: # type: ignore[misc] + ... + + @overload + def __getitem__(self, key: Hashable) -> DataArray: # type: ignore[misc] + ... + + @overload + def __getitem__(self, key: Any) -> Dataset: + ... + + def __getitem__(self, key) -> DataArray: + # TODO call the `_get_item` method of DataTree to allow path-like access to contents of other nodes + # For now just call Dataset.__getitem__ + return Dataset.__getitem__(self, key) + + @classmethod + def _construct_direct( + cls, + variables: dict[Any, Variable], + coord_names: set[Hashable], + dims: dict[Any, int] = None, + attrs: dict = None, + indexes: dict[Any, Index] = None, + encoding: dict = None, + close: Callable[[], None] = None, + ) -> Dataset: + """ + Overriding this method (along with ._replace) and modifying it to return a Dataset object + should hopefully ensure that the return type of any method on this object is a Dataset. + """ + if dims is None: + dims = calculate_dimensions(variables) + if indexes is None: + indexes = {} + obj = object.__new__(Dataset) + obj._variables = variables + obj._coord_names = coord_names + obj._dims = dims + obj._indexes = indexes + obj._attrs = attrs + obj._close = close + obj._encoding = encoding + return obj + + def _replace( + self, + variables: dict[Hashable, Variable] = None, + coord_names: set[Hashable] = None, + dims: dict[Any, int] = None, + attrs: dict[Hashable, Any] | None | Default = _default, + indexes: dict[Hashable, Index] = None, + encoding: dict | None | Default = _default, + inplace: bool = False, + ) -> Dataset: + """ + Overriding this method (along with ._construct_direct) and modifying it to return a Dataset object + should hopefully ensure that the return type of any method on this object is a Dataset. + """ + + if inplace: + raise AttributeError("In-place mutation of the DatasetView is not allowed") + + return Dataset._replace( + self, + variables=variables, + coord_names=coord_names, + dims=dims, + attrs=attrs, + indexes=indexes, + encoding=encoding, + inplace=inplace, + ) + + class DataTree( TreeNode, MappedDatasetMethodsMixin, @@ -217,10 +347,13 @@ def parent(self: DataTree, new_parent: DataTree) -> None: self._set_parent(new_parent, self.name) @property - def ds(self) -> Dataset: - """The data in this node, returned as a Dataset.""" - # TODO change this to return only an immutable view onto this node's data (see GH https://github.com/xarray-contrib/datatree/issues/80) - return self.to_dataset() + def ds(self) -> DatasetView: + """ + An immutable Dataset-like view onto the data in this node. + + For a mutable Dataset containing the same data as in this node, use `.to_dataset()` instead. + """ + return DatasetView._from_node(self) @ds.setter def ds(self, data: Union[Dataset, DataArray] = None) -> None: diff --git a/datatree/mapping.py b/datatree/mapping.py index 94d2c74..344842b 100644 --- a/datatree/mapping.py +++ b/datatree/mapping.py @@ -189,10 +189,10 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]: *args_as_tree_length_iterables, *list(kwargs_as_tree_length_iterables.values()), ): - node_args_as_datasets = [ + node_args_as_datasetviews = [ a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args] ] - node_kwargs_as_datasets = dict( + node_kwargs_as_datasetviews = dict( zip( [k for k in kwargs_as_tree_length_iterables.keys()], [ @@ -204,7 +204,7 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]: # Now we can call func on the data in this particular set of corresponding nodes results = ( - func(*node_args_as_datasets, **node_kwargs_as_datasets) + func(*node_args_as_datasetviews, **node_kwargs_as_datasetviews) if not node_of_first_tree.is_empty else None ) diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index b488ed0..86a7858 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -4,7 +4,7 @@ import pytest import xarray as xr import xarray.testing as xrt -from xarray.tests import source_ndarray +from xarray.tests import create_test_data, source_ndarray import datatree.testing as dtt from datatree import DataTree @@ -16,7 +16,7 @@ def test_empty(self): assert dt.name == "root" assert dt.parent is None assert dt.children == {} - xrt.assert_identical(dt.ds, xr.Dataset()) + xrt.assert_identical(dt.to_dataset(), xr.Dataset()) def test_unnamed(self): dt = DataTree() @@ -66,7 +66,7 @@ class TestStoreDatasets: def test_create_with_data(self): dat = xr.Dataset({"a": 0}) john = DataTree(name="john", data=dat) - xrt.assert_identical(john.ds, dat) + xrt.assert_identical(john.to_dataset(), dat) with pytest.raises(TypeError): DataTree(name="mary", parent=john, data="junk") # noqa @@ -75,7 +75,7 @@ def test_set_data(self): john = DataTree(name="john") dat = xr.Dataset({"a": 0}) john.ds = dat - xrt.assert_identical(john.ds, dat) + xrt.assert_identical(john.to_dataset(), dat) with pytest.raises(TypeError): john.ds = "junk" @@ -100,16 +100,9 @@ def test_assign_when_already_child_with_variables_name(self): dt.ds = xr.Dataset({"a": 0}) dt.ds = xr.Dataset() + new_ds = dt.to_dataset().assign(a=xr.DataArray(0)) with pytest.raises(KeyError, match="names would collide"): - dt.ds = dt.ds.assign(a=xr.DataArray(0)) - - @pytest.mark.xfail - def test_update_when_already_child_with_variables_name(self): - # See issue https://github.com/xarray-contrib/datatree/issues/38 - dt = DataTree(name="root", data=None) - DataTree(name="a", data=None, parent=dt) - with pytest.raises(KeyError, match="names would collide"): - dt.ds["a"] = xr.DataArray(0) + dt.ds = new_ds class TestGet: @@ -275,13 +268,13 @@ def test_setitem_new_empty_node(self): john["mary"] = DataTree() mary = john["mary"] assert isinstance(mary, DataTree) - xrt.assert_identical(mary.ds, xr.Dataset()) + xrt.assert_identical(mary.to_dataset(), xr.Dataset()) def test_setitem_overwrite_data_in_node_with_none(self): john = DataTree(name="john") mary = DataTree(name="mary", parent=john, data=xr.Dataset()) john["mary"] = DataTree() - xrt.assert_identical(mary.ds, xr.Dataset()) + xrt.assert_identical(mary.to_dataset(), xr.Dataset()) john.ds = xr.Dataset() with pytest.raises(ValueError, match="has no name"): @@ -292,21 +285,21 @@ def test_setitem_dataset_on_this_node(self): data = xr.Dataset({"temp": [0, 50]}) results = DataTree(name="results") results["."] = data - xrt.assert_identical(results.ds, data) + xrt.assert_identical(results.to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node(self): data = xr.Dataset({"temp": [0, 50]}) folder1 = DataTree(name="folder1") folder1["results"] = data - xrt.assert_identical(folder1["results"].ds, data) + xrt.assert_identical(folder1["results"].to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node_requiring_intermediate_nodes(self): data = xr.Dataset({"temp": [0, 50]}) folder1 = DataTree(name="folder1") folder1["results/highres"] = data - xrt.assert_identical(folder1["results/highres"].ds, data) + xrt.assert_identical(folder1["results/highres"].to_dataset(), data) def test_setitem_named_dataarray(self): da = xr.DataArray(name="temp", data=[0, 50]) @@ -341,7 +334,7 @@ def test_setitem_dataarray_replace_existing_node(self): p = xr.DataArray(data=[2, 3]) results["pressure"] = p expected = t.assign(pressure=p) - xrt.assert_identical(results.ds, expected) + xrt.assert_identical(results.to_dataset(), expected) class TestDictionaryInterface: @@ -355,16 +348,16 @@ def test_data_in_root(self): assert dt.name is None assert dt.parent is None assert dt.children == {} - xrt.assert_identical(dt.ds, dat) + xrt.assert_identical(dt.to_dataset(), dat) def test_one_layer(self): dat1, dat2 = xr.Dataset({"a": 1}), xr.Dataset({"b": 2}) dt = DataTree.from_dict({"run1": dat1, "run2": dat2}) - xrt.assert_identical(dt.ds, xr.Dataset()) + xrt.assert_identical(dt.to_dataset(), xr.Dataset()) assert dt.name is None - xrt.assert_identical(dt["run1"].ds, dat1) + xrt.assert_identical(dt["run1"].to_dataset(), dat1) assert dt["run1"].children == {} - xrt.assert_identical(dt["run2"].ds, dat2) + xrt.assert_identical(dt["run2"].to_dataset(), dat2) assert dt["run2"].children == {} def test_two_layers(self): @@ -373,13 +366,13 @@ def test_two_layers(self): assert "highres" in dt.children assert "lowres" in dt.children highres_run = dt["highres/run"] - xrt.assert_identical(highres_run.ds, dat1) + xrt.assert_identical(highres_run.to_dataset(), dat1) def test_nones(self): dt = DataTree.from_dict({"d": None, "d/e": None}) assert [node.name for node in dt.subtree] == [None, "d", "e"] assert [node.path for node in dt.subtree] == ["/", "/d", "/d/e"] - xrt.assert_identical(dt["d/e"].ds, xr.Dataset()) + xrt.assert_identical(dt["d/e"].to_dataset(), xr.Dataset()) def test_full(self, simple_datatree): dt = simple_datatree @@ -409,8 +402,44 @@ def test_roundtrip_unnamed_root(self, simple_datatree): assert roundtrip.equals(dt) -class TestBrowsing: - ... +class TestDatasetView: + def test_view_contents(self): + ds = create_test_data() + dt = DataTree(data=ds) + assert ds.identical( + dt.ds + ) # this only works because Dataset.identical doesn't check types + assert isinstance(dt.ds, xr.Dataset) + + def test_immutability(self): + # See issue https://github.com/xarray-contrib/datatree/issues/38 + dt = DataTree(name="root", data=None) + DataTree(name="a", data=None, parent=dt) + + with pytest.raises( + AttributeError, match="Mutation of the DatasetView is not allowed" + ): + dt.ds["a"] = xr.DataArray(0) + + with pytest.raises( + AttributeError, match="Mutation of the DatasetView is not allowed" + ): + dt.ds.update({"a": 0}) + + # TODO are there any other ways you can normally modify state (in-place)? + # (not attribute-like assignment because that doesn't work on Dataset anyway) + + def test_methods(self): + ds = create_test_data() + dt = DataTree(data=ds) + assert ds.mean().identical(dt.ds.mean()) + assert type(dt.ds.mean()) == xr.Dataset + + def test_arithmetic(self, create_test_datatree): + dt = create_test_datatree() + expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"] + result = 10.0 * dt["set1"].ds + assert result.identical(expected) class TestRestructuring: diff --git a/docs/source/whats-new.rst b/docs/source/whats-new.rst index e64ff54..8a31940 100644 --- a/docs/source/whats-new.rst +++ b/docs/source/whats-new.rst @@ -29,12 +29,20 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ +- The ``DataTree.ds`` attribute now returns a view onto an immutable Dataset-like object, instead of an actual instance + of ``xarray.Dataset``. This make break existing ``isinstance`` checks or ``assert`` comparisons. (:pull:`99`) + By `Tom Nicholas `_. + Deprecations ~~~~~~~~~~~~ Bug fixes ~~~~~~~~~ +- Modifying the contents of a ``DataTree`` object via the ``DataTree.ds`` attribute is now forbidden, which prevents + any possibility of the contents of a ``DataTree`` object and its ``.ds`` attribute diverging. (:issue:`38`, :pull:`99`) + By `Tom Nicholas `_. + Documentation ~~~~~~~~~~~~~