Skip to content
This repository has been archived by the owner on Oct 24, 2024. It is now read-only.

Commit

Permalink
[WIP] add DataTree.to_netcdf (#26)
Browse files Browse the repository at this point in the history
* first attempt at to_netcdf

* lint

* add test for roundtrip and support empty nodes

* Apply suggestions from code review

Co-authored-by: Tom Nicholas <[email protected]>

* update roundtrip test, improves empty node handling in IO

Co-authored-by: Tom Nicholas <[email protected]>
  • Loading branch information
Joe Hamman and TomNicholas authored Aug 25, 2021
1 parent 35fa1e3 commit 84f472a
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 7 deletions.
38 changes: 36 additions & 2 deletions datatree/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,10 +873,44 @@ def groups(self):
"""Return all netCDF4 groups in the tree, given as a tuple of path-like strings."""
return tuple(node.pathstr for node in self.subtree_nodes)

def to_netcdf(
    self, filepath, mode: str = "w", encoding=None, unlimited_dims=None, **kwargs
):
    """
    Write datatree contents to a netCDF file.

    Parameters
    ----------
    filepath : str or Path
        Path to which to save this datatree.
    mode : {"w", "a"}, default: "w"
        Write ('w') or append ('a') mode. If mode='w', any existing file at
        this location will be overwritten. If mode='a', existing variables
        will be overwritten. Only applies to the root group.
    encoding : dict, optional
        Nested dictionary with variable names as keys and dictionaries of
        variable specific encodings as values, e.g.,
        ``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1,
        "zlib": True}, ...}, ...}``. See ``xarray.Dataset.to_netcdf`` for
        available options.
    unlimited_dims : dict, optional
        Mapping of unlimited dimensions per group that should be serialized
        as unlimited dimensions. By default, no dimensions are treated as
        unlimited dimensions. Note that unlimited_dims may also be set via
        ``dataset.encoding["unlimited_dims"]``.
    kwargs :
        Additional keyword arguments to be passed to ``xarray.Dataset.to_netcdf``.
    """
    # Imported here rather than at module level to avoid a circular import
    # between datatree.py and io.py.
    from .io import _datatree_to_netcdf

    _datatree_to_netcdf(
        self,
        filepath,
        mode=mode,
        encoding=encoding,
        unlimited_dims=unlimited_dims,
        **kwargs,
    )

def plot(self):
raise NotImplementedError
Expand Down
80 changes: 78 additions & 2 deletions datatree/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,20 @@
from .datatree import DataNode, DataTree, PathType


def _ds_or_none(ds):
"""return none if ds is empty"""
if any(ds.coords) or any(ds.variables) or any(ds.attrs):
return ds
return None


def _open_group_children_recursively(filename, node, ncgroup, chunks, **kwargs):
for g in ncgroup.groups.values():

# Open and add this node's dataset to the tree
name = os.path.basename(g.path)
ds = open_dataset(filename, group=g.path, chunks=chunks, **kwargs)
ds = _ds_or_none(ds)
child_node = DataNode(name, ds)
node.add_child(child_node)

Expand Down Expand Up @@ -65,5 +73,73 @@ def open_mfdatatree(
return full_tree


def _datatree_to_netcdf(dt: DataTree, filepath: str):
raise NotImplementedError
def _maybe_extract_group_kwargs(enc, group):
try:
return enc[group]
except KeyError:
return None


def _create_empty_group(filename, group, mode):
    """Create an (empty) netCDF group named ``group`` in the file at ``filename``."""
    root = netCDF4.Dataset(filename, mode=mode)
    with root:
        root.createGroup(group)


def _datatree_to_netcdf(
    dt: DataTree,
    filepath,
    mode: str = "w",
    encoding=None,
    unlimited_dims=None,
    **kwargs,
):
    """
    Write ``dt`` and all of its descendants to groups of a single netCDF file.

    Parameters
    ----------
    dt : DataTree
        Tree whose node datasets are written, one netCDF group per node.
    filepath : str or Path
        Path of the target netCDF file.
    mode : {"w", "a"}, default: "w"
        Write or append mode. Only applies to the root group; every
        subsequent group is always written in append mode so earlier
        groups are preserved.
    encoding : dict, optional
        Mapping from node path strings to per-variable encoding dicts.
    unlimited_dims : dict, optional
        Mapping from node path strings to unlimited-dimension settings.
    kwargs :
        Additional keyword arguments forwarded to ``xarray.Dataset.to_netcdf``.
    """
    if kwargs.get("format", None) not in [None, "NETCDF4"]:
        raise ValueError("to_netcdf only supports the NETCDF4 format")

    if kwargs.get("engine", None) not in [None, "netcdf4", "h5netcdf"]:
        raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines")

    if kwargs.get("group", None) is not None:
        raise NotImplementedError(
            "specifying a root group for the tree has not been implemented"
        )

    if not kwargs.get("compute", True):
        raise NotImplementedError("compute=False has not been implemented yet")

    if encoding is None:
        encoding = {}

    if unlimited_dims is None:
        unlimited_dims = {}

    # Write the root node first (honouring the caller's ``mode``), then every
    # descendant. After the first write, switch to append mode so previously
    # written groups are not clobbered. Unifying root + descendants into one
    # loop removes the duplicated write logic the original had.
    for node in [dt] + list(dt.descendants):
        ds = node.ds
        # Group paths in the file are relative to the tree root.
        group_path = node.pathstr.replace(dt.root.pathstr, "")
        if ds is None:
            # Nodes without data still need an (empty) group so the tree
            # structure round-trips on re-open.
            _create_empty_group(filepath, group_path, mode)
        else:
            ds.to_netcdf(
                filepath,
                group=group_path,
                mode=mode,
                # BUG FIX: per-group options were previously looked up with
                # the root's path (``dt.pathstr``) for every descendant, so
                # child groups never received their own encoding or
                # unlimited_dims entries.
                encoding=_maybe_extract_group_kwargs(encoding, node.pathstr),
                unlimited_dims=_maybe_extract_group_kwargs(
                    unlimited_dims, node.pathstr
                ),
                **kwargs,
            )
        mode = "a"
24 changes: 21 additions & 3 deletions datatree/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from xarray.testing import assert_identical

from datatree import DataNode, DataTree
from datatree.io import open_datatree


def create_test_datatree():
Expand Down Expand Up @@ -31,14 +32,14 @@ def create_test_datatree():
| Dimensions: (x: 2, y: 3)
| Data variables:
| a (y) int64 6, 7, 8
| set1 (x) int64 9, 10
| set0 (x) int64 9, 10
The structure has deliberately repeated names of tags, variables, and
dimensions in order to better check for bugs caused by name conflicts.
"""
set1_data = xr.Dataset({"a": 0, "b": 1})
set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})
root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set1": ("x", [9, 10])})
root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})

# Avoid using __init__ so we can independently test it
# TODO change so it has a DataTree at the bottom
Expand Down Expand Up @@ -297,4 +298,21 @@ def test_repr_of_node_with_data(self):


class TestIO:
    def test_to_netcdf(self, tmpdir):
        """Round-trip a test tree through ``to_netcdf`` / ``open_datatree``."""
        # Casting to str avoids a pathlib bug in xarray.
        filepath = str(tmpdir / "test.nc")
        original_dt = create_test_datatree()
        original_dt.to_netcdf(filepath, engine="netcdf4")

        roundtrip_dt = open_datatree(filepath)

        # BUG FIX: this comparison was previously a bare expression with no
        # ``assert``, so a name mismatch could never fail the test.
        assert original_dt.name == roundtrip_dt.name
        assert original_dt.ds.identical(roundtrip_dt.ds)
        for a, b in zip(original_dt.descendants, roundtrip_dt.descendants):
            assert a.name == b.name
            assert a.pathstr == b.pathstr
            if a.has_data:
                assert a.ds.identical(b.ds)
            else:
                # Empty nodes round-trip as ``None`` datasets on both sides.
                assert a.ds is b.ds

0 comments on commit 84f472a

Please sign in to comment.