From 700f5d663a777b6ee67c50900632662cc7aeea18 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 25 Aug 2021 11:26:52 -0700 Subject: [PATCH 1/5] first attempt at to_netcdf --- datatree/datatree.py | 35 +++++++++++++++++++-- datatree/io.py | 55 +++++++++++++++++++++++++++++++-- datatree/tests/test_datatree.py | 10 ++++-- 3 files changed, 93 insertions(+), 7 deletions(-) diff --git a/datatree/datatree.py b/datatree/datatree.py index a3df42d1..174b94fb 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -874,10 +874,41 @@ def groups(self): """Return all netCDF4 groups in the tree, given as a tuple of path-like strings.""" return tuple(node.pathstr for node in self.subtree_nodes) - def to_netcdf(self, filename: str): + def to_netcdf(self, + filepath, + mode: str='w', + encoding=None, + unlimited_dims=None, + **kwargs + ): + ''' + Write datatree contents to a netCDF file. + + Paramters + --------- + filepath : str or Path + Path to which to save this datatree. + mode : {"w", "a"}, default: "w" + Write ('w') or append ('a') mode. If mode='w', any existing file at + this location will be overwritten. If mode='a', existing variables + will be overwritten. Only appies to the root group. + encoding : dict, optional + Nested dictionary with variable names as keys and dictionaries of + variable specific encodings as values, e.g., + ``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1, + "zlib": True}, ...}, ...}``. See ``xarray.Dataset.to_netcdf`` for available + options. + unlimited_dims : dict, optional + Mapping of unlimited dimensions per group that that should be serialized as unlimited dimensions. + By default, no dimensions are treated as unlimited dimensions. + Note that unlimited_dims may also be set via + ``dataset.encoding["unlimited_dims"]``. + kwargs : + Addional keyword arguments to be passed to ``xarray.Dataset.to_netcdf`` + ''' from .io import _datatree_to_netcdf - _datatree_to_netcdf(self, filename) + _datatree_to_netcdf(self, filepath, mode=mode, encoding=encoding, unlimited_dims=unlimited_dims, **kwargs) def plot(self): raise NotImplementedError diff --git a/datatree/io.py b/datatree/io.py index 727b4485..11db2cc6 100644 --- a/datatree/io.py +++ b/datatree/io.py @@ -65,5 +65,56 @@ def open_mfdatatree( return full_tree -def _datatree_to_netcdf(dt: DataTree, filepath: str): - raise NotImplementedError +def _maybe_extract_group_kwargs(enc, group): + try: + return enc[group] + except KeyError: + return None + + +def _datatree_to_netcdf( + dt: DataTree, + filepath, + mode: str='w', + encoding=None, + unlimited_dims=None, + **kwargs +): + + if kwargs.get('format', None) not in [None, 'NETCDF4']: + raise ValueError('to_netcdf only supports the NETCDF4 format') + + if kwargs.get('engine', None) not in [None, 'netcdf4', 'h5netcdf']: + raise ValueError('to_netcdf only supports the netcdf4 and h5netcdf engines') + + if not kwargs.get('compute', True): + raise NotImplementedError('compute=False has not been implemented yet') + + if encoding is None: + encoding = {} + + if unlimited_dims is None: + unlimited_dims = {} + + if dt.has_data: + dt.ds.to_netcdf( + filepath, + group=dt.pathstr, + mode=mode, + encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), + unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), + **kwargs + ) + mode = 'a' + + for node in dt.subtree_nodes: + if node.has_data: + node.ds.to_netcdf( + filepath, + group=node.pathstr, + mode=mode, + encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), + unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), + **kwargs + ) + mode = 'a' diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index a1266af0..313d6abf 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -31,14 +31,14 @@ def create_test_datatree(): | Dimensions: (x: 2, y: 3) | Data variables: | a (y) int64 6, 7, 8 - | set1 (x) int64 9, 10 + | set0 (x) int64 9, 10 The structure has deliberately repeated names of tags, variables, and dimensions in order to better check for bugs caused by name conflicts. """ set1_data = xr.Dataset({"a": 0, "b": 1}) set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}) - root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set1": ("x", [9, 10])}) + root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) # Avoid using __init__ so we can independently test it # TODO change so it has a DataTree at the bottom @@ -297,4 +297,8 @@ def test_repr_of_node_with_data(self): class TestIO: - ... + + def test_to_netcdf(self, tmpdir): + filepath = str(tmpdir / 'test.nc') # casting to str avoids a pathlib bug in xarray + dt = create_test_datatree() + dt.to_netcdf(filepath, engine='netcdf4') From 7052e20aeab84d5715d29c4d72bf0f1cd942f8d9 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 25 Aug 2021 11:27:26 -0700 Subject: [PATCH 2/5] lint --- datatree/datatree.py | 25 ++++++++++++++----------- datatree/io.py | 30 +++++++++++++++--------------- datatree/tests/test_datatree.py | 7 ++++--- 3 files changed, 33 insertions(+), 29 deletions(-) diff --git a/datatree/datatree.py b/datatree/datatree.py index 174b94fb..ae516fd3 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -874,16 +874,12 @@ def groups(self): """Return all netCDF4 groups in the tree, given as a tuple of path-like strings.""" return tuple(node.pathstr for node in self.subtree_nodes) - def to_netcdf(self, - filepath, - mode: str='w', - encoding=None, - unlimited_dims=None, - **kwargs - ): - ''' + def to_netcdf( + self, filepath, mode: str = "w", encoding=None, unlimited_dims=None, **kwargs + ): + """ Write datatree contents to a netCDF file. - + Paramters --------- filepath : str or Path @@ -905,10 +901,17 @@ def to_netcdf(self, ``dataset.encoding["unlimited_dims"]``. kwargs : Addional keyword arguments to be passed to ``xarray.Dataset.to_netcdf`` - ''' + """ from .io import _datatree_to_netcdf - _datatree_to_netcdf(self, filepath, mode=mode, encoding=encoding, unlimited_dims=unlimited_dims, **kwargs) + _datatree_to_netcdf( + self, + filepath, + mode=mode, + encoding=encoding, + unlimited_dims=unlimited_dims, + **kwargs, + ) def plot(self): raise NotImplementedError diff --git a/datatree/io.py b/datatree/io.py index 11db2cc6..0b0992cf 100644 --- a/datatree/io.py +++ b/datatree/io.py @@ -75,27 +75,27 @@ def _maybe_extract_group_kwargs(enc, group): def _datatree_to_netcdf( dt: DataTree, filepath, - mode: str='w', + mode: str = "w", encoding=None, unlimited_dims=None, **kwargs ): - if kwargs.get('format', None) not in [None, 'NETCDF4']: - raise ValueError('to_netcdf only supports the NETCDF4 format') - - if kwargs.get('engine', None) not in [None, 'netcdf4', 'h5netcdf']: - raise ValueError('to_netcdf only supports the netcdf4 and h5netcdf engines') - - if not kwargs.get('compute', True): - raise NotImplementedError('compute=False has not been implemented yet') - + if kwargs.get("format", None) not in [None, "NETCDF4"]: + raise ValueError("to_netcdf only supports the NETCDF4 format") + + if kwargs.get("engine", None) not in [None, "netcdf4", "h5netcdf"]: + raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines") + + if not kwargs.get("compute", True): + raise NotImplementedError("compute=False has not been implemented yet") + if encoding is None: encoding = {} - + if unlimited_dims is None: unlimited_dims = {} - + if dt.has_data: dt.ds.to_netcdf( filepath, @@ -105,8 +105,8 @@ def _datatree_to_netcdf( unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), **kwargs ) - mode = 'a' - + mode = "a" + for node in dt.subtree_nodes: if node.has_data: node.ds.to_netcdf( @@ -117,4 +117,4 @@ def _datatree_to_netcdf( unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), **kwargs ) - mode = 'a' + mode = "a" diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index 313d6abf..7ecb5ef1 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -297,8 +297,9 @@ def test_repr_of_node_with_data(self): class TestIO: - def test_to_netcdf(self, tmpdir): - filepath = str(tmpdir / 'test.nc') # casting to str avoids a pathlib bug in xarray + filepath = str( + tmpdir / "test.nc" + ) # casting to str avoids a pathlib bug in xarray dt = create_test_datatree() - dt.to_netcdf(filepath, engine='netcdf4') + dt.to_netcdf(filepath, engine="netcdf4") From f36b8a2d23be9d44df7b872428936680f14c8bb2 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 25 Aug 2021 12:46:32 -0700 Subject: [PATCH 3/5] add test for roundtrip and support empty nodes --- datatree/io.py | 37 ++++++++++++++++++--------------- datatree/tests/test_datatree.py | 15 +++++++++++-- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/datatree/io.py b/datatree/io.py index 0b0992cf..2323f2b1 100644 --- a/datatree/io.py +++ b/datatree/io.py @@ -2,7 +2,7 @@ from typing import Dict, Sequence import netCDF4 -from xarray import open_dataset +from xarray import Dataset, open_dataset from .datatree import DataNode, DataTree, PathType @@ -96,25 +96,28 @@ def _datatree_to_netcdf( if unlimited_dims is None: unlimited_dims = {} - if dt.has_data: - dt.ds.to_netcdf( + ds = dt.ds + if ds is None: + ds = Dataset() + ds.to_netcdf( + filepath, + group=dt.pathstr, + mode=mode, + encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), + unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), + **kwargs + ) + mode = "a" + + for node in dt.subtree_nodes: + ds = node.ds + if ds is None: + ds = Dataset() + ds.to_netcdf( filepath, - group=dt.pathstr, + group=node.pathstr, mode=mode, encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), **kwargs ) - mode = "a" - - for node in dt.subtree_nodes: - if node.has_data: - node.ds.to_netcdf( - filepath, - group=node.pathstr, - mode=mode, - encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), - unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), - **kwargs - ) - mode = "a" diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index 7ecb5ef1..16c5c615 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -4,6 +4,7 @@ from xarray.testing import assert_identical from datatree import DataNode, DataTree +from datatree.io import open_datatree def create_test_datatree(): @@ -301,5 +302,15 @@ def test_to_netcdf(self, tmpdir): filepath = str( tmpdir / "test.nc" ) # casting to str avoids a pathlib bug in xarray - dt = create_test_datatree() - dt.to_netcdf(filepath, engine="netcdf4") + original_dt = create_test_datatree() + original_dt.to_netcdf(filepath, engine="netcdf4") + + roundtrip_dt = open_datatree(filepath) + + # Q: why does the roundtrip_dt have an extra `root/` in all paths + print([n.pathstr for n in original_dt.subtree_nodes]) + print([n.pathstr for n in roundtrip_dt.subtree_nodes]) + + assert original_dt.ds.identical(original_dt.ds) + for node in original_dt.subtree_nodes: + assert node.ds.identical(original_dt["root/" + node.pathstr]) From 72d23f92ab6f8248ca5b6efc7934b7c491b687af Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Wed, 25 Aug 2021 14:31:48 -0700 Subject: [PATCH 4/5] Apply suggestions from code review Co-authored-by: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> --- datatree/io.py | 8 +++++--- datatree/tests/test_datatree.py | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/datatree/io.py b/datatree/io.py index 2323f2b1..55dd8804 100644 --- a/datatree/io.py +++ b/datatree/io.py @@ -99,9 +99,10 @@ def _datatree_to_netcdf( ds = dt.ds if ds is None: ds = Dataset() + group_path = dt.pathstr.replace(dt.root.pathstr, "") ds.to_netcdf( filepath, - group=dt.pathstr, + group=group_path, mode=mode, encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), @@ -109,13 +110,14 @@ def _datatree_to_netcdf( ) mode = "a" - for node in dt.subtree_nodes: + for node in dt.descendants: ds = node.ds if ds is None: ds = Dataset() + group_path = node.pathstr.replace(dt.root.pathstr, "") ds.to_netcdf( filepath, - group=node.pathstr, + group=group_path, mode=mode, encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index 16c5c615..2d549033 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -311,6 +311,6 @@ def test_to_netcdf(self, tmpdir): print([n.pathstr for n in original_dt.subtree_nodes]) print([n.pathstr for n in roundtrip_dt.subtree_nodes]) - assert original_dt.ds.identical(original_dt.ds) - for node in original_dt.subtree_nodes: - assert node.ds.identical(original_dt["root/" + node.pathstr]) + assert original_dt.ds.identical(roundtrip_dt.ds) + for node in original_dt.descendants: + assert node.ds.identical(roundtrip_dt[node.pathstr]) From a390c9c64c3d4b03f88b1e7a38800e3e234d6790 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 25 Aug 2021 15:07:21 -0700 Subject: [PATCH 5/5] update roundtrip test, improves empty node handling in IO --- datatree/io.py | 56 ++++++++++++++++++++++----------- datatree/tests/test_datatree.py | 14 +++++---- 2 files changed, 46 insertions(+), 24 deletions(-) diff --git a/datatree/io.py b/datatree/io.py index 55dd8804..c717203a 100644 --- a/datatree/io.py +++ b/datatree/io.py @@ -2,17 +2,25 @@ from typing import Dict, Sequence import netCDF4 -from xarray import Dataset, open_dataset +from xarray import open_dataset from .datatree import DataNode, DataTree, PathType +def _ds_or_none(ds): + """return none if ds is empty""" + if any(ds.coords) or any(ds.variables) or any(ds.attrs): + return ds + return None + + def _open_group_children_recursively(filename, node, ncgroup, chunks, **kwargs): for g in ncgroup.groups.values(): # Open and add this node's dataset to the tree name = os.path.basename(g.path) ds = open_dataset(filename, group=g.path, chunks=chunks, **kwargs) + ds = _ds_or_none(ds) child_node = DataNode(name, ds) node.add_child(child_node) @@ -72,6 +80,11 @@ def _maybe_extract_group_kwargs(enc, group): return None +def _create_empty_group(filename, group, mode): + with netCDF4.Dataset(filename, mode=mode) as rootgrp: + rootgrp.createGroup(group) + + def _datatree_to_netcdf( dt: DataTree, filepath, @@ -87,6 +100,11 @@ def _datatree_to_netcdf( if kwargs.get("engine", None) not in [None, "netcdf4", "h5netcdf"]: raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines") + if kwargs.get("group", None) is not None: + raise NotImplementedError( + "specifying a root group for the tree has not been implemented" + ) + if not kwargs.get("compute", True): raise NotImplementedError("compute=False has not been implemented yet") @@ -97,24 +115,10 @@ def _datatree_to_netcdf( unlimited_dims = {} ds = dt.ds - if ds is None: - ds = Dataset() group_path = dt.pathstr.replace(dt.root.pathstr, "") - ds.to_netcdf( - filepath, - group=group_path, - mode=mode, - encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), - unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), - **kwargs - ) - mode = "a" - - for node in dt.descendants: - ds = node.ds - if ds is None: - ds = Dataset() - group_path = node.pathstr.replace(dt.root.pathstr, "") + if ds is None: + _create_empty_group(filepath, group_path, mode) + else: ds.to_netcdf( filepath, group=group_path, @@ -123,3 +127,19 @@ def _datatree_to_netcdf( unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), **kwargs ) + mode = "a" + + for node in dt.descendants: + ds = node.ds + group_path = node.pathstr.replace(dt.root.pathstr, "") + if ds is None: + _create_empty_group(filepath, group_path, mode) + else: + ds.to_netcdf( + filepath, + group=group_path, + mode=mode, + encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), + unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), + **kwargs + ) diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index 2d549033..91412e6b 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -307,10 +307,12 @@ def test_to_netcdf(self, tmpdir): roundtrip_dt = open_datatree(filepath) - # Q: why does the roundtrip_dt have an extra `root/` in all paths - print([n.pathstr for n in original_dt.subtree_nodes]) - print([n.pathstr for n in roundtrip_dt.subtree_nodes]) - + original_dt.name == roundtrip_dt.name assert original_dt.ds.identical(roundtrip_dt.ds) - for node in original_dt.descendants: - assert node.ds.identical(roundtrip_dt[node.pathstr]) + for a, b in zip(original_dt.descendants, roundtrip_dt.descendants): + assert a.name == b.name + assert a.pathstr == b.pathstr + if a.has_data: + assert a.ds.identical(b.ds) + else: + assert a.ds is b.ds