Skip to content
This repository has been archived by the owner on Oct 24, 2024. It is now read-only.

[WIP] add DataTree.to_netcdf #26

Merged
merged 5 commits into from
Aug 25, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions datatree/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,10 +874,44 @@ def groups(self):
"""Return all netCDF4 groups in the tree, given as a tuple of path-like strings."""
return tuple(node.pathstr for node in self.subtree_nodes)

def to_netcdf(self, filename: str):
def to_netcdf(
self, filepath, mode: str = "w", encoding=None, unlimited_dims=None, **kwargs
):
"""
Write datatree contents to a netCDF file.

Paramters
---------
filepath : str or Path
Path to which to save this datatree.
mode : {"w", "a"}, default: "w"
Write ('w') or append ('a') mode. If mode='w', any existing file at
this location will be overwritten. If mode='a', existing variables
will be overwritten. Only appies to the root group.
encoding : dict, optional
Nested dictionary with variable names as keys and dictionaries of
variable specific encodings as values, e.g.,
``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1,
"zlib": True}, ...}, ...}``. See ``xarray.Dataset.to_netcdf`` for available
options.
unlimited_dims : dict, optional
Mapping of unlimited dimensions per group that that should be serialized as unlimited dimensions.
By default, no dimensions are treated as unlimited dimensions.
Note that unlimited_dims may also be set via
``dataset.encoding["unlimited_dims"]``.
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
kwargs :
Addional keyword arguments to be passed to ``xarray.Dataset.to_netcdf``
"""
from .io import _datatree_to_netcdf

_datatree_to_netcdf(self, filename)
_datatree_to_netcdf(
self,
filepath,
mode=mode,
encoding=encoding,
unlimited_dims=unlimited_dims,
**kwargs,
)

def plot(self):
raise NotImplementedError
Expand Down
62 changes: 59 additions & 3 deletions datatree/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Dict, Sequence

import netCDF4
from xarray import open_dataset
from xarray import Dataset, open_dataset

from .datatree import DataNode, DataTree, PathType

Expand Down Expand Up @@ -65,5 +65,61 @@ def open_mfdatatree(
return full_tree


def _datatree_to_netcdf(dt: DataTree, filepath: str):
raise NotImplementedError
def _maybe_extract_group_kwargs(enc, group):
try:
return enc[group]
except KeyError:
return None


def _datatree_to_netcdf(
dt: DataTree,
filepath,
mode: str = "w",
encoding=None,
unlimited_dims=None,
**kwargs
):

if kwargs.get("format", None) not in [None, "NETCDF4"]:
raise ValueError("to_netcdf only supports the NETCDF4 format")

if kwargs.get("engine", None) not in [None, "netcdf4", "h5netcdf"]:
raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines")

if not kwargs.get("compute", True):
raise NotImplementedError("compute=False has not been implemented yet")

if encoding is None:
encoding = {}

if unlimited_dims is None:
unlimited_dims = {}

ds = dt.ds
if ds is None:
ds = Dataset()
group_path = dt.pathstr.replace(dt.root.pathstr, "")
ds.to_netcdf(
filepath,
group=group_path,
mode=mode,
encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr),
unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr),
**kwargs
)
mode = "a"

for node in dt.descendants:
ds = node.ds
if ds is None:
ds = Dataset()
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
group_path = node.pathstr.replace(dt.root.pathstr, "")
ds.to_netcdf(
filepath,
group=group_path,
mode=mode,
encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr),
unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr),
**kwargs
)
22 changes: 19 additions & 3 deletions datatree/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from xarray.testing import assert_identical

from datatree import DataNode, DataTree
from datatree.io import open_datatree


def create_test_datatree():
Expand Down Expand Up @@ -31,14 +32,14 @@ def create_test_datatree():
| Dimensions: (x: 2, y: 3)
| Data variables:
| a (y) int64 6, 7, 8
| set1 (x) int64 9, 10
| set0 (x) int64 9, 10

The structure has deliberately repeated names of tags, variables, and
dimensions in order to better check for bugs caused by name conflicts.
"""
set1_data = xr.Dataset({"a": 0, "b": 1})
set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})
root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set1": ("x", [9, 10])})
root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved

# Avoid using __init__ so we can independently test it
# TODO change so it has a DataTree at the bottom
Expand Down Expand Up @@ -297,4 +298,19 @@ def test_repr_of_node_with_data(self):


class TestIO:
...
def test_to_netcdf(self, tmpdir):
filepath = str(
tmpdir / "test.nc"
) # casting to str avoids a pathlib bug in xarray
original_dt = create_test_datatree()
original_dt.to_netcdf(filepath, engine="netcdf4")

roundtrip_dt = open_datatree(filepath)

# Q: why does the roundtrip_dt have an extra `root/` in all paths
print([n.pathstr for n in original_dt.subtree_nodes])
print([n.pathstr for n in roundtrip_dt.subtree_nodes])
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved

assert original_dt.ds.identical(roundtrip_dt.ds)
for node in original_dt.descendants:
assert node.ds.identical(roundtrip_dt[node.pathstr])
jhamman marked this conversation as resolved.
Show resolved Hide resolved