From 8429d075615b871bf03dfc819312e6e8ea5d9223 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 29 Jun 2024 22:12:55 +0200 Subject: [PATCH 01/10] Add copy option in from_dict --- xarray/core/datatree.py | 13 +++++++++---- xarray/core/treenode.py | 8 +++++--- xarray/tests/test_datatree.py | 7 +++++++ 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index c923ca2eb87..1b3fd488a09 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -882,7 +882,9 @@ def __getitem__(self: DataTree, key: str) -> DataTree | DataArray: else: raise ValueError(f"Invalid format for key: {key}") - def _set(self, key: str, val: DataTree | CoercibleValue) -> None: + def _set( + self, key: str, val: DataTree | CoercibleValue, *, copy: bool = True + ) -> None: """ Set the child node or variable with the specified key to value. @@ -890,7 +892,7 @@ def _set(self, key: str, val: DataTree | CoercibleValue) -> None: """ if isinstance(val, DataTree): # create and assign a shallow copy here so as not to alter original name of node in grafted tree - new_node = val.copy(deep=False) + new_node = val.copy(deep=False) if copy else val new_node.name = key new_node.parent = self else: @@ -1052,6 +1054,8 @@ def from_dict( cls, d: MutableMapping[str, Dataset | DataArray | DataTree | None], name: str | None = None, + *, + copy: bool = True, ) -> DataTree: """ Create a datatree from a dictionary of data objects, organised by paths into the tree. 
@@ -1080,7 +1084,7 @@ def from_dict( # First create the root node root_data = d.pop("/", None) if isinstance(root_data, DataTree): - obj = root_data.copy() + obj = root_data.copy() if copy else root_data obj.orphan() else: obj = cls(name=name, data=root_data, parent=None, children=None) @@ -1091,7 +1095,7 @@ def from_dict( # Create and set new node node_name = NodePath(path).name if isinstance(data, DataTree): - new_node = data.copy() + new_node = data.copy() if copy else data new_node.orphan() else: new_node = cls(name=node_name, data=data) @@ -1100,6 +1104,7 @@ def from_dict( new_node, allow_overwrite=False, new_nodes_along_path=True, + copy=copy, ) return obj diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 6f51e1ffa38..d4950940992 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -483,6 +483,8 @@ def _set_item( item: Tree | T_DataArray, new_nodes_along_path: bool = False, allow_overwrite: bool = True, + *, + copy: bool = False, ) -> None: """ Set a new item in the tree, overwriting anything already present at that path. @@ -539,7 +541,7 @@ def _set_item( elif new_nodes_along_path: # Want child classes (i.e. 
DataTree) to populate tree with their own types new_node = type(self)() - current_node._set(part, new_node) + current_node._set(part, new_node, copy=copy) current_node = current_node.children[part] else: raise KeyError(f"Could not reach node at path {path}") @@ -547,11 +549,11 @@ def _set_item( if name in current_node.children: # Deal with anything already existing at this location if allow_overwrite: - current_node._set(name, item) + current_node._set(name, item, copy=copy) else: raise KeyError(f"Already a node object at path {path}") else: - current_node._set(name, item) + current_node._set(name, item, copy=copy) def __delitem__(self: Tree, key: str): """Remove a child node from this tree object.""" diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index b0dc2accd3e..6c247bde308 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -525,6 +525,13 @@ def test_roundtrip_unnamed_root(self, simple_datatree): roundtrip = DataTree.from_dict(dt.to_dict()) assert roundtrip.equals(dt) + @pytest.mark.parametrize("copy", [True, False]) + def test_copy(self, copy: bool) -> None: + run1 = DataTree.from_dict({"run1": xr.Dataset({"a": 1})}) + dt = DataTree.from_dict({"run1": run1}, copy=copy) + is_exact = dt["run1"] is run1 + assert is_exact is not copy + class TestDatasetView: def test_view_contents(self): From 6c83fdb004532bcf9f390c51508d4fcb410c5800 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 29 Jun 2024 22:22:16 +0200 Subject: [PATCH 02/10] Update treenode.py --- xarray/core/treenode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index d4950940992..9b5a363f661 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -468,7 +468,7 @@ def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray: current_node = current_node.get(part) return current_node - def _set(self: Tree, 
key: str, val: Tree) -> None: + def _set(self: Tree, key: str, val: Tree, *, copy: bool = True) -> None: """ Set the child node with the specified key to value. @@ -484,7 +484,7 @@ def _set_item( new_nodes_along_path: bool = False, allow_overwrite: bool = True, *, - copy: bool = False, + copy: bool = True, ) -> None: """ Set a new item in the tree, overwriting anything already present at that path. From 91b54e6947ffa3a61bd8a4c4e4ecfed5c62a0eca Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 29 Jun 2024 23:14:53 +0200 Subject: [PATCH 03/10] Add ASV benchmark --- asv_bench/benchmarks/datatree.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 asv_bench/benchmarks/datatree.py diff --git a/asv_bench/benchmarks/datatree.py b/asv_bench/benchmarks/datatree.py new file mode 100644 index 00000000000..2cde3ca3cc2 --- /dev/null +++ b/asv_bench/benchmarks/datatree.py @@ -0,0 +1,17 @@ +import numpy as np + +import xarray as xr +from xarray.core.datatree import DataTree + +from . 
import parameterized + + +class Datatree: + def setup(self): + dat1 = xr.Dataset({"a": 1}) + run1 = DataTree.from_dict({"run1": dat1}) + self.d = {"run1": run1} + + @parameterized(["copy"], [(True, False)]) + def time_from_dict(self, copy: bool): + DataTree.from_dict(self.d, copy=copy) From 04b5293212b24f3ab0630fddaee333f133201f35 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 29 Jun 2024 21:15:34 +0000 Subject: [PATCH 04/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- asv_bench/benchmarks/datatree.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/asv_bench/benchmarks/datatree.py b/asv_bench/benchmarks/datatree.py index 2cde3ca3cc2..ef41254ae0d 100644 --- a/asv_bench/benchmarks/datatree.py +++ b/asv_bench/benchmarks/datatree.py @@ -1,5 +1,3 @@ -import numpy as np - import xarray as xr from xarray.core.datatree import DataTree From 53fe271f854379c17a9640ba0d35322bd3d1aa19 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 29 Jun 2024 23:20:45 +0200 Subject: [PATCH 05/10] Update datatree.py --- asv_bench/benchmarks/datatree.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/asv_bench/benchmarks/datatree.py b/asv_bench/benchmarks/datatree.py index 2cde3ca3cc2..302a6e2fcce 100644 --- a/asv_bench/benchmarks/datatree.py +++ b/asv_bench/benchmarks/datatree.py @@ -8,8 +8,7 @@ class Datatree: def setup(self): - dat1 = xr.Dataset({"a": 1}) - run1 = DataTree.from_dict({"run1": dat1}) + run1 = DataTree.from_dict({"run1": xr.Dataset({"a": 1})}) self.d = {"run1": run1} @parameterized(["copy"], [(True, False)]) From cae1b1c84bea3f2717e46e1709470ec2eeb7fdb9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 29 Jun 2024 23:33:42 +0200 Subject: [PATCH 06/10] Update whats-new.rst --- doc/whats-new.rst | 3 +++ 1 file changed, 3 
insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c58f73cb1fa..821d0cf43e0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,9 @@ v2024.06.1 (unreleased) New Features ~~~~~~~~~~~~ +- Add optional parameter `copy` to :py:func:`xarray.core.datatree.DataTree.from_dict`, + which by default shallow copies the values in the dict. + By `Jimmy Westling <https://github.com/illviljan>`_. - Allow chunking for arrays with duplicated dimension names (:issue:`8759`, :pull:`9099`). By `Martin Raspaud <https://github.com/mraspaud>`_. From ae930a5f4db5ef5d79018b8bcbeb316ea0399816 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 29 Jun 2024 23:34:30 +0200 Subject: [PATCH 07/10] Update whats-new.rst --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 821d0cf43e0..2d0667d02e7 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,7 +23,7 @@ v2024.06.1 (unreleased) New Features ~~~~~~~~~~~~ - Add optional parameter `copy` to :py:func:`xarray.core.datatree.DataTree.from_dict`, - which by default shallow copies the values in the dict. + which by default shallow copies the values in the dict. (:pull:`9193`) By `Jimmy Westling <https://github.com/illviljan>`_. - Allow chunking for arrays with duplicated dimension names (:issue:`8759`, :pull:`9099`). By `Martin Raspaud <https://github.com/mraspaud>`_. From 866a4b3da2c34180d574a7e60073739f6900c3bf Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 29 Jun 2024 23:49:39 +0200 Subject: [PATCH 08/10] Update docstrings --- xarray/core/datatree.py | 2 ++ xarray/core/treenode.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 1b3fd488a09..9d13e3be87f 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1071,6 +1071,8 @@ def from_dict( To assign data to the root node of the tree use "/" as the path. name : Hashable | None, optional Name for the root node of the tree. 
Default is None. + copy : bool, optional + Whether to make a shallow copy of the values in the dict. Default is True. Returns ------- diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 9b5a363f661..23d2519a84d 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -502,6 +502,8 @@ def _set_item( allow_overwrite : bool Whether or not to overwrite any existing node at the location given by path. + copy : bool, optional + Whether to make a shallow copy of the values in the node. Default is True. Raises ------ From a40b42286972d6863472962b7f8e1490768b7d1a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 12 Jul 2024 00:36:15 +0200 Subject: [PATCH 09/10] Rename copy to fastpath and invert checks. --- asv_bench/benchmarks/datatree.py | 6 +++--- doc/whats-new.rst | 2 +- xarray/core/datatree.py | 16 ++++++++-------- xarray/core/treenode.py | 14 +++++++------- xarray/tests/test_datatree.py | 8 ++++---- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/asv_bench/benchmarks/datatree.py b/asv_bench/benchmarks/datatree.py index 4eaeaa3ad9b..3db0c708225 100644 --- a/asv_bench/benchmarks/datatree.py +++ b/asv_bench/benchmarks/datatree.py @@ -9,6 +9,6 @@ def setup(self): run1 = DataTree.from_dict({"run1": xr.Dataset({"a": 1})}) self.d = {"run1": run1} - @parameterized(["copy"], [(True, False)]) - def time_from_dict(self, copy: bool): - DataTree.from_dict(self.d, copy=copy) + @parameterized(["fastpath"], [(False, True)]) + def time_from_dict(self, fastpath: bool): + DataTree.from_dict(self.d, fastpath=fastpath) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 2d0667d02e7..b0324ba151b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,7 +22,7 @@ v2024.06.1 (unreleased) New Features ~~~~~~~~~~~~ -- Add optional parameter `copy` to :py:func:`xarray.core.datatree.DataTree.from_dict`, +- Add optional parameter `fastpath` to :py:func:`xarray.core.datatree.DataTree.from_dict`, 
which by default shallow copies the values in the dict. (:pull:`9193`) By `Jimmy Westling <https://github.com/illviljan>`_. - Allow chunking for arrays with duplicated dimension names (:issue:`8759`, :pull:`9099`). diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 9d13e3be87f..ec1716f0783 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -883,7 +883,7 @@ def __getitem__(self: DataTree, key: str) -> DataTree | DataArray: raise ValueError(f"Invalid format for key: {key}") def _set( - self, key: str, val: DataTree | CoercibleValue, *, copy: bool = True + self, key: str, val: DataTree | CoercibleValue, *, fastpath: bool = False ) -> None: """ Set the child node or variable with the specified key to value. @@ -892,7 +892,7 @@ def _set( """ if isinstance(val, DataTree): # create and assign a shallow copy here so as not to alter original name of node in grafted tree - new_node = val.copy(deep=False) if copy else val + new_node = val.copy(deep=False) if not fastpath else val new_node.name = key new_node.parent = self else: @@ -1055,7 +1055,7 @@ def from_dict( d: MutableMapping[str, Dataset | DataArray | DataTree | None], name: str | None = None, *, - copy: bool = True, + fastpath: bool = False, ) -> DataTree: """ Create a datatree from a dictionary of data objects, organised by paths into the tree. @@ -1071,8 +1071,8 @@ def from_dict( To assign data to the root node of the tree use "/" as the path. name : Hashable | None, optional Name for the root node of the tree. Default is None. - copy : bool, optional - Whether to make a shallow copy of the values in the dict. Default is True. + fastpath : bool, optional + Whether to bypass checks and avoid shallow copies of the values in the dict. Default is False. 
Returns ------- @@ -1086,7 +1086,7 @@ def from_dict( # First create the root node root_data = d.pop("/", None) if isinstance(root_data, DataTree): - obj = root_data.copy() if copy else root_data + obj = root_data.copy() if not fastpath else root_data obj.orphan() else: obj = cls(name=name, data=root_data, parent=None, children=None) @@ -1097,7 +1097,7 @@ def from_dict( # Create and set new node node_name = NodePath(path).name if isinstance(data, DataTree): - new_node = data.copy() if copy else data + new_node = data.copy() if not fastpath else data new_node.orphan() else: new_node = cls(name=node_name, data=data) @@ -1106,7 +1106,7 @@ def from_dict( new_node, allow_overwrite=False, new_nodes_along_path=True, - copy=copy, + fastpath=fastpath, ) return obj diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 23d2519a84d..570dc9d934c 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -468,7 +468,7 @@ def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray: current_node = current_node.get(part) return current_node - def _set(self: Tree, key: str, val: Tree, *, copy: bool = True) -> None: + def _set(self: Tree, key: str, val: Tree, *, fastpath: bool = False) -> None: """ Set the child node with the specified key to value. @@ -484,7 +484,7 @@ def _set_item( new_nodes_along_path: bool = False, allow_overwrite: bool = True, *, - copy: bool = True, + fastpath: bool = False, ) -> None: """ Set a new item in the tree, overwriting anything already present at that path. @@ -502,8 +502,8 @@ def _set_item( allow_overwrite : bool Whether or not to overwrite any existing node at the location given by path. - copy : bool, optional - Whether to make a shallow copy of the values in the node. Default is True. + fastpath : bool, optional + Whether to bypass checks and avoid shallow copies of the values in the dict. Default is False. Raises ------ @@ -543,7 +543,7 @@ def _set_item( elif new_nodes_along_path: # Want child classes (i.e. 
DataTree) to populate tree with their own types new_node = type(self)() - current_node._set(part, new_node, copy=copy) + current_node._set(part, new_node, fastpath=fastpath) current_node = current_node.children[part] else: raise KeyError(f"Could not reach node at path {path}") @@ -551,11 +551,11 @@ def _set_item( if name in current_node.children: # Deal with anything already existing at this location if allow_overwrite: - current_node._set(name, item, copy=copy) + current_node._set(name, item, fastpath=fastpath) else: raise KeyError(f"Already a node object at path {path}") else: - current_node._set(name, item, copy=copy) + current_node._set(name, item, fastpath=fastpath) def __delitem__(self: Tree, key: str): """Remove a child node from this tree object.""" diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 6c247bde308..cf26318d557 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -525,12 +525,12 @@ def test_roundtrip_unnamed_root(self, simple_datatree): roundtrip = DataTree.from_dict(dt.to_dict()) assert roundtrip.equals(dt) - @pytest.mark.parametrize("copy", [True, False]) - def test_copy(self, copy: bool) -> None: + @pytest.mark.parametrize("copy", [False, True]) + def test_fastpath(self, fastpath: bool) -> None: run1 = DataTree.from_dict({"run1": xr.Dataset({"a": 1})}) - dt = DataTree.from_dict({"run1": run1}, copy=copy) + dt = DataTree.from_dict({"run1": run1}, fastpath=fastpath) is_exact = dt["run1"] is run1 - assert is_exact is not copy + assert is_exact is fastpath class TestDatasetView: From 7604934ff479af547e9da696664c38e878038030 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 12 Jul 2024 00:45:16 +0200 Subject: [PATCH 10/10] Update xarray/tests/test_datatree.py --- xarray/tests/test_datatree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 
cf26318d557..a8adebe3d04 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -525,7 +525,7 @@ def test_roundtrip_unnamed_root(self, simple_datatree): roundtrip = DataTree.from_dict(dt.to_dict()) assert roundtrip.equals(dt) - @pytest.mark.parametrize("copy", [False, True]) + @pytest.mark.parametrize("fastpath", [False, True]) def test_fastpath(self, fastpath: bool) -> None: run1 = DataTree.from_dict({"run1": xr.Dataset({"a": 1})}) dt = DataTree.from_dict({"run1": run1}, fastpath=fastpath)