This repository has been archived by the owner on Oct 24, 2024. It is now read-only.

[WIP] add DataTree.to_netcdf #26

Merged (5 commits) on Aug 25, 2021
38 changes: 36 additions & 2 deletions datatree/datatree.py
@@ -874,10 +874,44 @@ def groups(self):
"""Return all netCDF4 groups in the tree, given as a tuple of path-like strings."""
return tuple(node.pathstr for node in self.subtree_nodes)

    def to_netcdf(self, filename: str):
    def to_netcdf(
        self, filepath, mode: str = "w", encoding=None, unlimited_dims=None, **kwargs
    ):
        """
        Write datatree contents to a netCDF file.

        Parameters
        ----------
        filepath : str or Path
            Path to which to save this datatree.
        mode : {"w", "a"}, default: "w"
            Write ('w') or append ('a') mode. If mode='w', any existing file at
            this location will be overwritten. If mode='a', existing variables
            will be overwritten. Only applies to the root group.
        encoding : dict, optional
            Nested dictionary with variable names as keys and dictionaries of
            variable-specific encodings as values, e.g.,
            ``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1,
            "zlib": True}, ...}, ...}``. See ``xarray.Dataset.to_netcdf`` for
            available options.
        unlimited_dims : dict, optional
            Mapping of unlimited dimensions per group that should be serialized
            as unlimited dimensions. By default, no dimensions are treated as
            unlimited. Note that unlimited_dims may also be set via
            ``dataset.encoding["unlimited_dims"]``.
        kwargs :
            Additional keyword arguments to be passed to ``xarray.Dataset.to_netcdf``.
        """
        from .io import _datatree_to_netcdf

        _datatree_to_netcdf(self, filename)
        _datatree_to_netcdf(
            self,
            filepath,
            mode=mode,
            encoding=encoding,
            unlimited_dims=unlimited_dims,
            **kwargs,
        )

    def plot(self):
        raise NotImplementedError
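A minimal usage sketch of the new method (not part of the diff), assuming a tree built with DataNode as in the tests below and that the root node exposes to_netcdf; the file name, variables, and child group are illustrative:

import xarray as xr
from datatree import DataNode

# Hypothetical two-group tree: a root dataset plus one child group "set1".
root = DataNode("root", xr.Dataset({"a": ("y", [6, 7, 8])}))
root.add_child(DataNode("set1", xr.Dataset({"b": ("x", [1, 2])})))

# Per-group encoding: outer keys are group paths, inner dicts map each
# variable to its encoding options, as in the docstring example above.
encoding = {"root/set1": {"b": {"dtype": "int16", "zlib": True}}}
root.to_netcdf("example.nc", encoding=encoding)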
80 changes: 78 additions & 2 deletions datatree/io.py
@@ -7,12 +7,20 @@
from .datatree import DataNode, DataTree, PathType


def _ds_or_none(ds):
    """Return None if the dataset is empty, otherwise return it unchanged."""
    if any(ds.coords) or any(ds.variables) or any(ds.attrs):
        return ds
    return None


def _open_group_children_recursively(filename, node, ncgroup, chunks, **kwargs):
    for g in ncgroup.groups.values():

        # Open and add this node's dataset to the tree
        name = os.path.basename(g.path)
        ds = open_dataset(filename, group=g.path, chunks=chunks, **kwargs)
        ds = _ds_or_none(ds)
        child_node = DataNode(name, ds)
        node.add_child(child_node)

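For reference, the recursion above follows netCDF4's own group tree; a standalone sketch of the same traversal, using only the netCDF4 API (file name illustrative):

import netCDF4

def walk_groups(ncgroup, prefix=""):
    # Each netCDF4 group exposes its children via the .groups mapping.
    for name, child in ncgroup.groups.items():
        path = f"{prefix}/{name}"
        print(path)
        walk_groups(child, path)

with netCDF4.Dataset("example.nc") as rootgrp:
    walk_groups(rootgrp)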
@@ -65,5 +73,73 @@ def open_mfdatatree(
    return full_tree


def _datatree_to_netcdf(dt: DataTree, filepath: str):
    raise NotImplementedError
def _maybe_extract_group_kwargs(enc, group):
    try:
        return enc[group]
    except KeyError:
        return None


def _create_empty_group(filename, group, mode):
    with netCDF4.Dataset(filename, mode=mode) as rootgrp:
        rootgrp.createGroup(group)


def _datatree_to_netcdf(
    dt: DataTree,
    filepath,
    mode: str = "w",
    encoding=None,
    unlimited_dims=None,
    **kwargs
):
    if kwargs.get("format", None) not in [None, "NETCDF4"]:
        raise ValueError("to_netcdf only supports the NETCDF4 format")

    if kwargs.get("engine", None) not in [None, "netcdf4", "h5netcdf"]:
        raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines")

    if kwargs.get("group", None) is not None:
        raise NotImplementedError(
            "specifying a root group for the tree has not been implemented"
        )

    if not kwargs.get("compute", True):
        raise NotImplementedError("compute=False has not been implemented yet")

    if encoding is None:
        encoding = {}

    if unlimited_dims is None:
        unlimited_dims = {}

    # Write the root group first, using the caller's mode
    ds = dt.ds
    group_path = dt.pathstr.replace(dt.root.pathstr, "")
    if ds is None:
        _create_empty_group(filepath, group_path, mode)
    else:
        ds.to_netcdf(
            filepath,
            group=group_path,
            mode=mode,
            encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr),
            unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr),
            **kwargs
        )
    # All remaining groups are appended to the same file
    mode = "a"

    for node in dt.descendants:
        ds = node.ds
        group_path = node.pathstr.replace(dt.root.pathstr, "")
        if ds is None:
            _create_empty_group(filepath, group_path, mode)
        else:
            ds.to_netcdf(
                filepath,
                group=group_path,
                mode=mode,
                # per-group kwargs are keyed by each node's own path
                encoding=_maybe_extract_group_kwargs(encoding, node.pathstr),
                unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, node.pathstr),
                **kwargs
            )
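The function above reduces to one xarray call per group: the root group is written with the caller's mode (creating or truncating the file for mode="w"), after which mode switches to "a" so every remaining group appends to the same file. A minimal sketch of that pattern in plain xarray (file and variable names illustrative):

import xarray as xr

# Root group first: mode="w" creates or overwrites the file.
xr.Dataset({"a": ("y", [1, 2, 3])}).to_netcdf("tree.nc", mode="w")
# Child groups are then appended to the same file, one group per call.
xr.Dataset({"b": ("x", [4, 5])}).to_netcdf("tree.nc", group="/set1", mode="a")
xr.Dataset({"c": 0}).to_netcdf("tree.nc", group="/set1/set2", mode="a")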
24 changes: 21 additions & 3 deletions datatree/tests/test_datatree.py
@@ -4,6 +4,7 @@
from xarray.testing import assert_identical

from datatree import DataNode, DataTree
from datatree.io import open_datatree


def create_test_datatree():
@@ -31,14 +32,14 @@ def create_test_datatree():
    | Dimensions: (x: 2, y: 3)
    | Data variables:
    |     a (y) int64 6, 7, 8
    |     set1 (x) int64 9, 10
    |     set0 (x) int64 9, 10

    The structure has deliberately repeated names of tags, variables, and
    dimensions in order to better check for bugs caused by name conflicts.
    """
    set1_data = xr.Dataset({"a": 0, "b": 1})
    set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})
    root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set1": ("x", [9, 10])})
    root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})

    # Avoid using __init__ so we can independently test it
    # TODO change so it has a DataTree at the bottom
@@ -297,4 +298,21 @@ def test_repr_of_node_with_data(self):


class TestIO:
    ...
    def test_to_netcdf(self, tmpdir):
        filepath = str(
            tmpdir / "test.nc"
        )  # casting to str avoids a pathlib bug in xarray
        original_dt = create_test_datatree()
        original_dt.to_netcdf(filepath, engine="netcdf4")

        roundtrip_dt = open_datatree(filepath)

        assert original_dt.name == roundtrip_dt.name
        assert original_dt.ds.identical(roundtrip_dt.ds)
        for a, b in zip(original_dt.descendants, roundtrip_dt.descendants):
            assert a.name == b.name
            assert a.pathstr == b.pathstr
            if a.has_data:
                assert a.ds.identical(b.ds)
            else:
                assert a.ds is b.ds