Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add close() method to DataTree and use it to clean-up open files in tests #9651

Merged
merged 4 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
import os
import time
import traceback
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from glob import glob
from typing import TYPE_CHECKING, Any, ClassVar

import numpy as np

from xarray.conventions import cf_encoder
from xarray.core import indexing
from xarray.core.datatree import DataTree
from xarray.core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
Expand All @@ -20,7 +21,6 @@
from io import BufferedIOBase

from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
from xarray.core.types import NestedSequence

# Create a logger object, but don't add any handlers. Leave that to user code.
Expand Down Expand Up @@ -149,6 +149,18 @@ def find_root_and_group(ds):
return ds, group


def datatree_from_io_dict(groups_dict: Mapping[str, Dataset]) -> DataTree:
    """Build a tree via ``DataTree.from_dict``, with file clean-up on failure.

    If tree construction raises (e.g. due to misaligned groups), every
    dataset in ``groups_dict`` is closed before the exception propagates,
    so no file handles are leaked. On success, each node of the resulting
    tree is given the closer of the dataset it was built from.
    """
    try:
        tree = DataTree.from_dict(groups_dict)
    except Exception:
        # Construction failed: release the file handles held by the
        # already-opened datasets before re-raising.
        for dataset in groups_dict.values():
            dataset.close()
        raise
    closers = {path: dataset._close for path, dataset in groups_dict.items()}
    tree.set_close(closers)
    return tree


def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500):
"""
Robustly index an array, using retry logic with exponential backoff if any
Expand Down
6 changes: 2 additions & 4 deletions xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
BackendEntrypoint,
WritableCFDataStore,
_normalize_path,
datatree_from_io_dict,
find_root_and_group,
)
from xarray.backends.file_manager import CachingFileManager, DummyFileManager
Expand Down Expand Up @@ -474,8 +475,6 @@ def open_datatree(
driver_kwds=None,
**kwargs,
) -> DataTree:
from xarray.core.datatree import DataTree

groups_dict = self.open_groups_as_dict(
filename_or_obj,
mask_and_scale=mask_and_scale,
Expand All @@ -495,8 +494,7 @@ def open_datatree(
driver_kwds=driver_kwds,
**kwargs,
)

return DataTree.from_dict(groups_dict)
return datatree_from_io_dict(groups_dict)

def open_groups_as_dict(
self,
Expand Down
6 changes: 2 additions & 4 deletions xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
BackendEntrypoint,
WritableCFDataStore,
_normalize_path,
datatree_from_io_dict,
find_root_and_group,
robust_getitem,
)
Expand Down Expand Up @@ -710,8 +711,6 @@ def open_datatree(
autoclose=False,
**kwargs,
) -> DataTree:
from xarray.core.datatree import DataTree

groups_dict = self.open_groups_as_dict(
filename_or_obj,
mask_and_scale=mask_and_scale,
Expand All @@ -730,8 +729,7 @@ def open_datatree(
autoclose=autoclose,
**kwargs,
)

return DataTree.from_dict(groups_dict)
return datatree_from_io_dict(groups_dict)

def open_groups_as_dict(
self,
Expand Down
6 changes: 2 additions & 4 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
BackendEntrypoint,
_encode_variable_name,
_normalize_path,
datatree_from_io_dict,
)
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core import indexing
Expand Down Expand Up @@ -1290,8 +1291,6 @@ def open_datatree(
zarr_version=None,
**kwargs,
) -> DataTree:
from xarray.core.datatree import DataTree

filename_or_obj = _normalize_path(filename_or_obj)
groups_dict = self.open_groups_as_dict(
filename_or_obj=filename_or_obj,
Expand All @@ -1312,8 +1311,7 @@ def open_datatree(
zarr_version=zarr_version,
**kwargs,
)

return DataTree.from_dict(groups_dict)
return datatree_from_io_dict(groups_dict)

def open_groups_as_dict(
self,
Expand Down
16 changes: 15 additions & 1 deletion xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def _to_dataset_view(self, rebuild_dims: bool, inherit: bool) -> DatasetView:
attrs=self._attrs,
indexes=dict(self._indexes if inherit else self._node_indexes),
encoding=self._encoding,
close=None,
close=self._close,
)

@property
Expand Down Expand Up @@ -796,6 +796,20 @@ def _repr_html_(self):
return f"<pre>{escape(repr(self))}</pre>"
return datatree_repr_html(self)

def __enter__(self) -> Self:
    """Enter a ``with`` block, returning the tree itself unchanged."""
    return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
    """Exit a ``with`` block by closing the tree (delegates to ``close()``)."""
    self.close()

def close(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def close(self):
def close(self) -> None:

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

for node in self.subtree:
node.dataset.close()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to add a test to verify that calling close() repeatedly does not raise an error.

(Dataset.close is not idempotent, because it replaces _close with None, but the dataset objects on which I'm calling it here are not persistent. Probably I should also make DatasetView.close raise an error.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def set_close(self, closers: Mapping[str, Callable[[], None] | None], /) -> None:
    """Set file-closing callbacks on nodes of this tree.

    Parameters
    ----------
    closers
        Mapping from node path to the closer to install on that node, or
        ``None`` to clear it. Each closer is stored on the node's
        ``_close`` attribute.

    Notes
    -----
    Raises ``KeyError`` if a path in ``closers`` does not name an existing
    node of this tree (lookup goes through ``self[path]``).
    """
    for path, close in closers.items():
        self[path]._close = close
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option would be to only set the closer on the root node, similar to Dataset.set_close().

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean you can presumably pull out the dataset from that node and .set_close on that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The datasets created by DataTree.dataset are ephemeral, so that wouldn't work.

(I think I should probably change this to only act at the local node level)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've modified set_close to only work at the local node level, but close() still closes everything in the subtree.


def _replace_node(
self: DataTree,
data: Dataset | Default = _default,
Expand Down
130 changes: 68 additions & 62 deletions xarray/tests/test_backends_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,9 @@ def test_to_netcdf(self, tmpdir, simple_datatree):
original_dt = simple_datatree
original_dt.to_netcdf(filepath, engine=self.engine)

roundtrip_dt = open_datatree(filepath, engine=self.engine)
assert_equal(original_dt, roundtrip_dt)
with open_datatree(filepath, engine=self.engine) as roundtrip_dt:
assert roundtrip_dt._close is not None
assert_equal(original_dt, roundtrip_dt)

def test_to_netcdf_inherited_coords(self, tmpdir):
filepath = tmpdir / "test.nc"
Expand All @@ -128,10 +129,10 @@ def test_to_netcdf_inherited_coords(self, tmpdir):
)
original_dt.to_netcdf(filepath, engine=self.engine)

roundtrip_dt = open_datatree(filepath, engine=self.engine)
assert_equal(original_dt, roundtrip_dt)
subtree = cast(DataTree, roundtrip_dt["/sub"])
assert "x" not in subtree.to_dataset(inherit=False).coords
with open_datatree(filepath, engine=self.engine) as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)
subtree = cast(DataTree, roundtrip_dt["/sub"])
assert "x" not in subtree.to_dataset(inherit=False).coords

def test_netcdf_encoding(self, tmpdir, simple_datatree):
filepath = tmpdir / "test.nc"
Expand All @@ -142,14 +143,13 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree):
enc = {"/set2": {var: comp for var in original_dt["/set2"].dataset.data_vars}}

original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine)
roundtrip_dt = open_datatree(filepath, engine=self.engine)
with open_datatree(filepath, engine=self.engine) as roundtrip_dt:
assert roundtrip_dt["/set2/a"].encoding["zlib"] == comp["zlib"]
assert roundtrip_dt["/set2/a"].encoding["complevel"] == comp["complevel"]

assert roundtrip_dt["/set2/a"].encoding["zlib"] == comp["zlib"]
assert roundtrip_dt["/set2/a"].encoding["complevel"] == comp["complevel"]

enc["/not/a/group"] = {"foo": "bar"} # type: ignore[dict-item]
with pytest.raises(ValueError, match="unexpected encoding group.*"):
original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine)
enc["/not/a/group"] = {"foo": "bar"} # type: ignore[dict-item]
with pytest.raises(ValueError, match="unexpected encoding group.*"):
original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine)


@requires_netCDF4
Expand Down Expand Up @@ -179,18 +179,17 @@ def test_open_groups(self, unaligned_datatree_nc) -> None:
assert "/Group1" in unaligned_dict_of_datasets.keys()
assert "/Group1/subgroup1" in unaligned_dict_of_datasets.keys()
# Check that group name returns the correct datasets
assert_identical(
unaligned_dict_of_datasets["/"],
xr.open_dataset(unaligned_datatree_nc, group="/"),
)
assert_identical(
unaligned_dict_of_datasets["/Group1"],
xr.open_dataset(unaligned_datatree_nc, group="Group1"),
)
assert_identical(
unaligned_dict_of_datasets["/Group1/subgroup1"],
xr.open_dataset(unaligned_datatree_nc, group="/Group1/subgroup1"),
)
with xr.open_dataset(unaligned_datatree_nc, group="/") as expected:
assert_identical(unaligned_dict_of_datasets["/"], expected)
with xr.open_dataset(unaligned_datatree_nc, group="Group1") as expected:
assert_identical(unaligned_dict_of_datasets["/Group1"], expected)
with xr.open_dataset(
unaligned_datatree_nc, group="/Group1/subgroup1"
) as expected:
assert_identical(unaligned_dict_of_datasets["/Group1/subgroup1"], expected)

for ds in unaligned_dict_of_datasets.values():
ds.close()

def test_open_groups_to_dict(self, tmpdir) -> None:
"""Create an aligned netCDF4 with the following structure to test `open_groups`
Expand Down Expand Up @@ -234,8 +233,10 @@ def test_open_groups_to_dict(self, tmpdir) -> None:

aligned_dict_of_datasets = open_groups(filepath)
aligned_dt = DataTree.from_dict(aligned_dict_of_datasets)

assert open_datatree(filepath).identical(aligned_dt)
with open_datatree(filepath) as opened_tree:
assert opened_tree.identical(aligned_dt)
for ds in aligned_dict_of_datasets.values():
ds.close()


@requires_h5netcdf
Expand All @@ -252,8 +253,8 @@ def test_to_zarr(self, tmpdir, simple_datatree):
original_dt = simple_datatree
original_dt.to_zarr(filepath)

roundtrip_dt = open_datatree(filepath, engine="zarr")
assert_equal(original_dt, roundtrip_dt)
with open_datatree(filepath, engine="zarr") as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)

def test_zarr_encoding(self, tmpdir, simple_datatree):
import zarr
Expand All @@ -264,14 +265,14 @@ def test_zarr_encoding(self, tmpdir, simple_datatree):
comp = {"compressor": zarr.Blosc(cname="zstd", clevel=3, shuffle=2)}
enc = {"/set2": {var: comp for var in original_dt["/set2"].dataset.data_vars}}
original_dt.to_zarr(filepath, encoding=enc)
roundtrip_dt = open_datatree(filepath, engine="zarr")

print(roundtrip_dt["/set2/a"].encoding)
assert roundtrip_dt["/set2/a"].encoding["compressor"] == comp["compressor"]
with open_datatree(filepath, engine="zarr") as roundtrip_dt:
print(roundtrip_dt["/set2/a"].encoding)
assert roundtrip_dt["/set2/a"].encoding["compressor"] == comp["compressor"]

enc["/not/a/group"] = {"foo": "bar"} # type: ignore[dict-item]
with pytest.raises(ValueError, match="unexpected encoding group.*"):
original_dt.to_zarr(filepath, encoding=enc, engine="zarr")
enc["/not/a/group"] = {"foo": "bar"} # type: ignore[dict-item]
with pytest.raises(ValueError, match="unexpected encoding group.*"):
original_dt.to_zarr(filepath, encoding=enc, engine="zarr")

def test_to_zarr_zip_store(self, tmpdir, simple_datatree):
from zarr.storage import ZipStore
Expand All @@ -281,8 +282,8 @@ def test_to_zarr_zip_store(self, tmpdir, simple_datatree):
store = ZipStore(filepath)
original_dt.to_zarr(store)

roundtrip_dt = open_datatree(store, engine="zarr")
assert_equal(original_dt, roundtrip_dt)
with open_datatree(store, engine="zarr") as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)

def test_to_zarr_not_consolidated(self, tmpdir, simple_datatree):
filepath = tmpdir / "test.zarr"
Expand All @@ -295,8 +296,8 @@ def test_to_zarr_not_consolidated(self, tmpdir, simple_datatree):
assert not s1zmetadata.exists()

with pytest.warns(RuntimeWarning, match="consolidated"):
roundtrip_dt = open_datatree(filepath, engine="zarr")
assert_equal(original_dt, roundtrip_dt)
with open_datatree(filepath, engine="zarr") as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)

def test_to_zarr_default_write_mode(self, tmpdir, simple_datatree):
import zarr
Expand All @@ -317,10 +318,10 @@ def test_to_zarr_inherited_coords(self, tmpdir):
filepath = tmpdir / "test.zarr"
original_dt.to_zarr(filepath)

roundtrip_dt = open_datatree(filepath, engine="zarr")
assert_equal(original_dt, roundtrip_dt)
subtree = cast(DataTree, roundtrip_dt["/sub"])
assert "x" not in subtree.to_dataset(inherit=False).coords
with open_datatree(filepath, engine="zarr") as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)
subtree = cast(DataTree, roundtrip_dt["/sub"])
assert "x" not in subtree.to_dataset(inherit=False).coords

def test_open_groups_round_trip(self, tmpdir, simple_datatree) -> None:
"""Test `open_groups` opens a zarr store with the `simple_datatree` structure."""
Expand All @@ -331,7 +332,11 @@ def test_open_groups_round_trip(self, tmpdir, simple_datatree) -> None:
roundtrip_dict = open_groups(filepath, engine="zarr")
roundtrip_dt = DataTree.from_dict(roundtrip_dict)

assert open_datatree(filepath, engine="zarr").identical(roundtrip_dt)
with open_datatree(filepath, engine="zarr") as opened_tree:
assert opened_tree.identical(roundtrip_dt)

for ds in roundtrip_dict.values():
ds.close()

def test_open_datatree(self, unaligned_datatree_zarr) -> None:
"""Test if `open_datatree` fails to open a zarr store with an unaligned group hierarchy."""
Expand All @@ -353,21 +358,22 @@ def test_open_groups(self, unaligned_datatree_zarr) -> None:
assert "/Group1/subgroup1" in unaligned_dict_of_datasets.keys()
assert "/Group2" in unaligned_dict_of_datasets.keys()
# Check that group name returns the correct datasets
assert_identical(
unaligned_dict_of_datasets["/"],
xr.open_dataset(unaligned_datatree_zarr, group="/", engine="zarr"),
)
assert_identical(
unaligned_dict_of_datasets["/Group1"],
xr.open_dataset(unaligned_datatree_zarr, group="Group1", engine="zarr"),
)
assert_identical(
unaligned_dict_of_datasets["/Group1/subgroup1"],
xr.open_dataset(
unaligned_datatree_zarr, group="/Group1/subgroup1", engine="zarr"
),
)
assert_identical(
unaligned_dict_of_datasets["/Group2"],
xr.open_dataset(unaligned_datatree_zarr, group="/Group2", engine="zarr"),
)
with xr.open_dataset(
unaligned_datatree_zarr, group="/", engine="zarr"
) as expected:
assert_identical(unaligned_dict_of_datasets["/"], expected)
with xr.open_dataset(
unaligned_datatree_zarr, group="Group1", engine="zarr"
) as expected:
assert_identical(unaligned_dict_of_datasets["/Group1"], expected)
with xr.open_dataset(
unaligned_datatree_zarr, group="/Group1/subgroup1", engine="zarr"
) as expected:
assert_identical(unaligned_dict_of_datasets["/Group1/subgroup1"], expected)
with xr.open_dataset(
unaligned_datatree_zarr, group="/Group2", engine="zarr"
) as expected:
assert_identical(unaligned_dict_of_datasets["/Group2"], expected)

for ds in unaligned_dict_of_datasets.values():
ds.close()
Loading