Skip to content
This repository has been archived by the owner on Oct 24, 2024. It is now read-only.

Commit

Permalink
Replace .ds with immutable DatasetView (#99)
Browse files Browse the repository at this point in the history
* sketching out changes needed to integrate variables into DataTree

* fixed some other basic conflicts

* fix mypy errors

* can create basic datatree node objects again

* child-variable name collisions detected correctly

* in-progress

* add _replace method

* updated tests to assert identical instead of check .ds is expected_ds

* refactor .ds setter to use _replace

* refactor init to use _replace

* refactor test tree to avoid init

* attempt at copy methods

* rewrote implementation of .copy method

* xfailing test for deepcopying

* pseudocode implementation of DatasetView

* Revert "pseudocode implementation of DatasetView"

This reverts commit 52ef23b.

* pseudocode implementation of DatasetView

* removed duplicated implementation of copy

* reorganise API docs

* expose data_vars, coords etc. properties

* try except with calculate_dimensions private import

* add keys/values/items methods

* don't use has_data when .variables would do

* change asserts to not fail just because of differing types

* full sketch of how DatasetView could work

* added tests for DatasetView

* remove commented pseudocode

* explanation of basic properties

* add data structures page to index

* revert adding documentation in favour of that going in a different PR

* correct deepcopy tests

* use .data_vars in copy tests

* add test for arithmetic with .ds

* remove reference to wrapping node in DatasetView

* clarify type through renaming variables

* remove test for out-of-node access

* make imports depend on most recent version of xarray

Co-authored-by: Mattia Almansi <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove try except for internal import

* depend on latest pre-release of xarray

* correct name of version

* xarray pre-release under pip in ci envs

* correct methods

* whatsnews

* fix fixture in test

* whatsnew

* improve docstrings

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: Mattia Almansi <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 16, 2022
1 parent c2b6400 commit 9a6f674
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 34 deletions.
141 changes: 137 additions & 4 deletions datatree/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Set,
Tuple,
Union,
overload,
)

import pandas as pd
Expand Down Expand Up @@ -86,6 +87,135 @@ def _check_for_name_collisions(
)


class DatasetView(Dataset):
    """
    An immutable Dataset-like view onto the data in a single DataTree node.

    In-place operations modifying this object should raise an AttributeError.
    Operations returning a new result will return a new xarray.Dataset object.
    This includes all API on Dataset, which will be inherited.

    This requires overriding all inherited private constructors.
    """

    # TODO what happens if user alters (in-place) a DataArray they extracted from this object?

    def __init__(
        self,
        data_vars: Mapping[Any, Any] | None = None,
        coords: Mapping[Any, Any] | None = None,
        attrs: Mapping[Any, Any] | None = None,
    ):
        """Always raise: views are only built through ``_from_node``.

        Raises
        ------
        AttributeError
            Unconditionally, whatever arguments are passed.
        """
        raise AttributeError("DatasetView objects are not to be initialized directly")

    @classmethod
    def _from_node(
        cls,
        wrapping_node: DataTree,
    ) -> DatasetView:
        """Constructor, using dataset attributes from wrapping node.

        The private dataset attributes of ``wrapping_node`` are assigned onto
        the new object by reference (nothing is copied here), and ``__init__``
        is bypassed because it always raises.
        """

        obj: DatasetView = object.__new__(cls)
        obj._variables = wrapping_node._variables
        obj._coord_names = wrapping_node._coord_names
        obj._dims = wrapping_node._dims
        obj._indexes = wrapping_node._indexes
        obj._attrs = wrapping_node._attrs
        obj._close = wrapping_node._close
        obj._encoding = wrapping_node._encoding

        return obj

    def __setitem__(self, key, val) -> None:
        """Forbid item assignment on the view.

        Raises
        ------
        AttributeError
            Unconditionally.
        """
        raise AttributeError(
            "Mutation of the DatasetView is not allowed, please use __setitem__ on the wrapping DataTree node, "
            "or use `DataTree.to_dataset()` if you want a mutable dataset"
        )

    def update(self, other) -> None:
        """Forbid in-place update of the view.

        Raises
        ------
        AttributeError
            Unconditionally.
        """
        raise AttributeError(
            "Mutation of the DatasetView is not allowed, please use .update on the wrapping DataTree node, "
            "or use `DataTree.to_dataset()` if you want a mutable dataset"
        )

    # FIXME https://github.com/python/mypy/issues/7328
    @overload
    def __getitem__(self, key: Mapping) -> Dataset:  # type: ignore[misc]
        ...

    @overload
    def __getitem__(self, key: Hashable) -> DataArray:  # type: ignore[misc]
        ...

    @overload
    def __getitem__(self, key: Any) -> Dataset:
        ...

    def __getitem__(self, key) -> Dataset | DataArray:
        """Look up ``key`` by delegating to ``Dataset.__getitem__``.

        The return annotation covers both result types declared by the
        overloads above (a ``Dataset`` or a ``DataArray``).
        """
        # TODO call the `_get_item` method of DataTree to allow path-like access to contents of other nodes
        # For now just call Dataset.__getitem__
        return Dataset.__getitem__(self, key)

    @classmethod
    def _construct_direct(
        cls,
        variables: dict[Any, Variable],
        coord_names: set[Hashable],
        dims: dict[Any, int] | None = None,
        attrs: dict | None = None,
        indexes: dict[Any, Index] | None = None,
        encoding: dict | None = None,
        close: Callable[[], None] | None = None,
    ) -> Dataset:
        """
        Overriding this method (along with ._replace) and modifying it to return a Dataset object
        should hopefully ensure that the return type of any method on this object is a Dataset.

        When ``dims`` is None it is computed with ``calculate_dimensions(variables)``;
        when ``indexes`` is None an empty dict is used.
        """
        if dims is None:
            dims = calculate_dimensions(variables)
        if indexes is None:
            indexes = {}
        # Deliberately instantiate a plain ``Dataset`` rather than ``cls``, so
        # that results of operations on the view are ordinary datasets.
        obj = object.__new__(Dataset)
        obj._variables = variables
        obj._coord_names = coord_names
        obj._dims = dims
        obj._indexes = indexes
        obj._attrs = attrs
        obj._close = close
        obj._encoding = encoding
        return obj

    def _replace(
        self,
        variables: dict[Hashable, Variable] | None = None,
        coord_names: set[Hashable] | None = None,
        dims: dict[Any, int] | None = None,
        attrs: dict[Hashable, Any] | None | Default = _default,
        indexes: dict[Hashable, Index] | None = None,
        encoding: dict | None | Default = _default,
        inplace: bool = False,
    ) -> Dataset:
        """
        Overriding this method (along with ._construct_direct) and modifying it to return a Dataset object
        should hopefully ensure that the return type of any method on this object is a Dataset.

        Raises
        ------
        AttributeError
            If ``inplace`` is truthy, since the view must never be mutated.
        """

        if inplace:
            raise AttributeError("In-place mutation of the DatasetView is not allowed")

        # ``inplace`` is necessarily falsy past the guard above.
        return Dataset._replace(
            self,
            variables=variables,
            coord_names=coord_names,
            dims=dims,
            attrs=attrs,
            indexes=indexes,
            encoding=encoding,
            inplace=inplace,
        )


class DataTree(
TreeNode,
MappedDatasetMethodsMixin,
Expand Down Expand Up @@ -217,10 +347,13 @@ def parent(self: DataTree, new_parent: DataTree) -> None:
self._set_parent(new_parent, self.name)

@property
def ds(self) -> Dataset:
"""The data in this node, returned as a Dataset."""
# TODO change this to return only an immutable view onto this node's data (see GH #80)
return self.to_dataset()
def ds(self) -> DatasetView:
"""
An immutable Dataset-like view onto the data in this node.
For a mutable Dataset containing the same data as in this node, use `.to_dataset()` instead.
"""
return DatasetView._from_node(self)

@ds.setter
def ds(self, data: Union[Dataset, DataArray] = None) -> None:
Expand Down
6 changes: 3 additions & 3 deletions datatree/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
*args_as_tree_length_iterables,
*list(kwargs_as_tree_length_iterables.values()),
):
node_args_as_datasets = [
node_args_as_datasetviews = [
a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args]
]
node_kwargs_as_datasets = dict(
node_kwargs_as_datasetviews = dict(
zip(
[k for k in kwargs_as_tree_length_iterables.keys()],
[
Expand All @@ -204,7 +204,7 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:

# Now we can call func on the data in this particular set of corresponding nodes
results = (
func(*node_args_as_datasets, **node_kwargs_as_datasets)
func(*node_args_as_datasetviews, **node_kwargs_as_datasetviews)
if not node_of_first_tree.is_empty
else None
)
Expand Down
83 changes: 56 additions & 27 deletions datatree/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import xarray as xr
import xarray.testing as xrt
from xarray.tests import source_ndarray
from xarray.tests import create_test_data, source_ndarray

import datatree.testing as dtt
from datatree import DataTree
Expand All @@ -16,7 +16,7 @@ def test_empty(self):
assert dt.name == "root"
assert dt.parent is None
assert dt.children == {}
xrt.assert_identical(dt.ds, xr.Dataset())
xrt.assert_identical(dt.to_dataset(), xr.Dataset())

def test_unnamed(self):
dt = DataTree()
Expand Down Expand Up @@ -66,7 +66,7 @@ class TestStoreDatasets:
def test_create_with_data(self):
dat = xr.Dataset({"a": 0})
john = DataTree(name="john", data=dat)
xrt.assert_identical(john.ds, dat)
xrt.assert_identical(john.to_dataset(), dat)

with pytest.raises(TypeError):
DataTree(name="mary", parent=john, data="junk") # noqa
Expand All @@ -75,7 +75,7 @@ def test_set_data(self):
john = DataTree(name="john")
dat = xr.Dataset({"a": 0})
john.ds = dat
xrt.assert_identical(john.ds, dat)
xrt.assert_identical(john.to_dataset(), dat)
with pytest.raises(TypeError):
john.ds = "junk"

Expand All @@ -100,16 +100,9 @@ def test_assign_when_already_child_with_variables_name(self):
dt.ds = xr.Dataset({"a": 0})

dt.ds = xr.Dataset()
new_ds = dt.to_dataset().assign(a=xr.DataArray(0))
with pytest.raises(KeyError, match="names would collide"):
dt.ds = dt.ds.assign(a=xr.DataArray(0))

@pytest.mark.xfail
def test_update_when_already_child_with_variables_name(self):
# See issue #38
dt = DataTree(name="root", data=None)
DataTree(name="a", data=None, parent=dt)
with pytest.raises(KeyError, match="names would collide"):
dt.ds["a"] = xr.DataArray(0)
dt.ds = new_ds


class TestGet:
Expand Down Expand Up @@ -275,13 +268,13 @@ def test_setitem_new_empty_node(self):
john["mary"] = DataTree()
mary = john["mary"]
assert isinstance(mary, DataTree)
xrt.assert_identical(mary.ds, xr.Dataset())
xrt.assert_identical(mary.to_dataset(), xr.Dataset())

def test_setitem_overwrite_data_in_node_with_none(self):
john = DataTree(name="john")
mary = DataTree(name="mary", parent=john, data=xr.Dataset())
john["mary"] = DataTree()
xrt.assert_identical(mary.ds, xr.Dataset())
xrt.assert_identical(mary.to_dataset(), xr.Dataset())

john.ds = xr.Dataset()
with pytest.raises(ValueError, match="has no name"):
Expand All @@ -292,21 +285,21 @@ def test_setitem_dataset_on_this_node(self):
data = xr.Dataset({"temp": [0, 50]})
results = DataTree(name="results")
results["."] = data
xrt.assert_identical(results.ds, data)
xrt.assert_identical(results.to_dataset(), data)

@pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes")
def test_setitem_dataset_as_new_node(self):
data = xr.Dataset({"temp": [0, 50]})
folder1 = DataTree(name="folder1")
folder1["results"] = data
xrt.assert_identical(folder1["results"].ds, data)
xrt.assert_identical(folder1["results"].to_dataset(), data)

@pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes")
def test_setitem_dataset_as_new_node_requiring_intermediate_nodes(self):
data = xr.Dataset({"temp": [0, 50]})
folder1 = DataTree(name="folder1")
folder1["results/highres"] = data
xrt.assert_identical(folder1["results/highres"].ds, data)
xrt.assert_identical(folder1["results/highres"].to_dataset(), data)

def test_setitem_named_dataarray(self):
da = xr.DataArray(name="temp", data=[0, 50])
Expand Down Expand Up @@ -341,7 +334,7 @@ def test_setitem_dataarray_replace_existing_node(self):
p = xr.DataArray(data=[2, 3])
results["pressure"] = p
expected = t.assign(pressure=p)
xrt.assert_identical(results.ds, expected)
xrt.assert_identical(results.to_dataset(), expected)


class TestDictionaryInterface:
Expand All @@ -355,16 +348,16 @@ def test_data_in_root(self):
assert dt.name is None
assert dt.parent is None
assert dt.children == {}
xrt.assert_identical(dt.ds, dat)
xrt.assert_identical(dt.to_dataset(), dat)

def test_one_layer(self):
dat1, dat2 = xr.Dataset({"a": 1}), xr.Dataset({"b": 2})
dt = DataTree.from_dict({"run1": dat1, "run2": dat2})
xrt.assert_identical(dt.ds, xr.Dataset())
xrt.assert_identical(dt.to_dataset(), xr.Dataset())
assert dt.name is None
xrt.assert_identical(dt["run1"].ds, dat1)
xrt.assert_identical(dt["run1"].to_dataset(), dat1)
assert dt["run1"].children == {}
xrt.assert_identical(dt["run2"].ds, dat2)
xrt.assert_identical(dt["run2"].to_dataset(), dat2)
assert dt["run2"].children == {}

def test_two_layers(self):
Expand All @@ -373,13 +366,13 @@ def test_two_layers(self):
assert "highres" in dt.children
assert "lowres" in dt.children
highres_run = dt["highres/run"]
xrt.assert_identical(highres_run.ds, dat1)
xrt.assert_identical(highres_run.to_dataset(), dat1)

def test_nones(self):
dt = DataTree.from_dict({"d": None, "d/e": None})
assert [node.name for node in dt.subtree] == [None, "d", "e"]
assert [node.path for node in dt.subtree] == ["/", "/d", "/d/e"]
xrt.assert_identical(dt["d/e"].ds, xr.Dataset())
xrt.assert_identical(dt["d/e"].to_dataset(), xr.Dataset())

def test_full(self, simple_datatree):
dt = simple_datatree
Expand Down Expand Up @@ -409,8 +402,44 @@ def test_roundtrip_unnamed_root(self, simple_datatree):
assert roundtrip.equals(dt)


class TestBrowsing:
...
class TestDatasetView:
    """Tests for the immutable Dataset-like object returned by ``DataTree.ds``."""

    def test_view_contents(self):
        ds = create_test_data()
        dt = DataTree(data=ds)
        # this only works because Dataset.identical doesn't check types
        assert ds.identical(dt.ds)
        assert isinstance(dt.ds, xr.Dataset)

    def test_immutability(self):
        # See issue #38
        dt = DataTree(name="root", data=None)
        DataTree(name="a", data=None, parent=dt)

        with pytest.raises(
            AttributeError, match="Mutation of the DatasetView is not allowed"
        ):
            dt.ds["a"] = xr.DataArray(0)

        with pytest.raises(
            AttributeError, match="Mutation of the DatasetView is not allowed"
        ):
            dt.ds.update({"a": 0})

        # TODO are there any other ways you can normally modify state (in-place)?
        # (not attribute-like assignment because that doesn't work on Dataset anyway)

    def test_methods(self):
        ds = create_test_data()
        dt = DataTree(data=ds)
        assert ds.mean().identical(dt.ds.mean())
        # Exact-type check on purpose: the result must be a plain Dataset, not
        # a subclass of it, so ``isinstance`` would be too permissive here.
        assert type(dt.ds.mean()) is xr.Dataset

    def test_arithmetic(self, create_test_datatree):
        dt = create_test_datatree()
        # Compare Dataset with Dataset rather than relying on a DataTree node
        # happening to expose the private attributes ``identical`` reads.
        expected = create_test_datatree(modify=lambda ds: 10.0 * ds)[
            "set1"
        ].to_dataset()
        result = 10.0 * dt["set1"].ds
        assert result.identical(expected)


class TestRestructuring:
Expand Down
8 changes: 8 additions & 0 deletions docs/source/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,20 @@ New Features
Breaking changes
~~~~~~~~~~~~~~~~

- The ``DataTree.ds`` attribute now returns a view onto an immutable Dataset-like object, instead of an actual instance
of ``xarray.Dataset``. This may break existing ``isinstance`` checks or ``assert`` comparisons. (:pull:`99`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Deprecations
~~~~~~~~~~~~

Bug fixes
~~~~~~~~~

- Modifying the contents of a ``DataTree`` object via the ``DataTree.ds`` attribute is now forbidden, which prevents
any possibility of the contents of a ``DataTree`` object and its ``.ds`` attribute diverging. (:issue:`38`, :pull:`99`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Documentation
~~~~~~~~~~~~~

Expand Down

0 comments on commit 9a6f674

Please sign in to comment.