Skip to content

Commit

Permalink
Unit tests for closing functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Oct 21, 2024
1 parent 5b793b8 commit 994550f
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 15 deletions.
5 changes: 3 additions & 2 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,16 @@ def find_root_and_group(ds):
return ds, group


def datatree_from_io_dict(groups_dict: Mapping[str, Dataset]) -> DataTree:
def datatree_from_dict_with_io_cleanup(groups_dict: Mapping[str, Dataset]) -> DataTree:
"""DataTree.from_dict with file clean-up."""
try:
tree = DataTree.from_dict(groups_dict)
except Exception:
for ds in groups_dict.values():
ds.close()
raise
tree.set_close({path: ds._close for path, ds in groups_dict.items()})
for path, ds in groups_dict.items():
tree[path].set_close(ds._close)
return tree


Expand Down
4 changes: 2 additions & 2 deletions xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
BackendEntrypoint,
WritableCFDataStore,
_normalize_path,
datatree_from_io_dict,
datatree_from_dict_with_io_cleanup,
find_root_and_group,
)
from xarray.backends.file_manager import CachingFileManager, DummyFileManager
Expand Down Expand Up @@ -494,7 +494,7 @@ def open_datatree(
driver_kwds=driver_kwds,
**kwargs,
)
return datatree_from_io_dict(groups_dict)
return datatree_from_dict_with_io_cleanup(groups_dict)

def open_groups_as_dict(
self,
Expand Down
4 changes: 2 additions & 2 deletions xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
BackendEntrypoint,
WritableCFDataStore,
_normalize_path,
datatree_from_io_dict,
datatree_from_dict_with_io_cleanup,
find_root_and_group,
robust_getitem,
)
Expand Down Expand Up @@ -729,7 +729,7 @@ def open_datatree(
autoclose=autoclose,
**kwargs,
)
return datatree_from_io_dict(groups_dict)
return datatree_from_dict_with_io_cleanup(groups_dict)

def open_groups_as_dict(
self,
Expand Down
4 changes: 2 additions & 2 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
BackendEntrypoint,
_encode_variable_name,
_normalize_path,
datatree_from_io_dict,
datatree_from_dict_with_io_cleanup,
)
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core import indexing
Expand Down Expand Up @@ -1311,7 +1311,7 @@ def open_datatree(
zarr_version=zarr_version,
**kwargs,
)
return datatree_from_io_dict(groups_dict)
return datatree_from_dict_with_io_cleanup(groups_dict)

def open_groups_as_dict(
self,
Expand Down
32 changes: 25 additions & 7 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,15 @@ def update(self, other) -> NoReturn:
"use `.copy()` first to get a mutable version of the input dataset."
)

def set_close(self, close: Callable[[], None] | None) -> None:
raise AttributeError("cannot modify a DatasetView()")

def close(self) -> None:
raise AttributeError(
"cannot close a DatasetView(). Close the associated DataTree node "
"instead"
)

# FIXME https://github.com/python/mypy/issues/7328
@overload # type: ignore[override]
def __getitem__(self, key: Mapping) -> Dataset: # type: ignore[overload-overlap]
Expand Down Expand Up @@ -580,7 +589,7 @@ def _to_dataset_view(self, rebuild_dims: bool, inherit: bool) -> DatasetView:
attrs=self._attrs,
indexes=dict(self._indexes if inherit else self._node_indexes),
encoding=self._encoding,
close=self._close,
close=None,
)

@property
Expand Down Expand Up @@ -633,7 +642,7 @@ def to_dataset(self, inherit: bool = True) -> Dataset:
None if self._attrs is None else dict(self._attrs),
dict(self._indexes if inherit else self._node_indexes),
None if self._encoding is None else dict(self._encoding),
self._close,
None,
)

@property
Expand Down Expand Up @@ -802,13 +811,22 @@ def __enter__(self) -> Self:
def __exit__(self, exc_type, exc_value, traceback) -> None:
self.close()

def close(self):
# DatasetView does not support close() or set_close(), so we reimplement
# these methods on DataTree.

def _close_node(self) -> None:
if self._close is not None:
self._close()
self._close = None

def close(self) -> None:
"""Close any files associated with this tree."""
for node in self.subtree:
node.dataset.close()
node._close_node()

def set_close(self, closers: Mapping[str, Callable[[], None] | None], /) -> None:
for path, close in closers.items():
self[path]._close = close
def set_close(self, close: Callable[[], None] | None) -> None:
"""Set the closer for this node."""
self._close = close

def _replace_node(
self: DataTree,
Expand Down
72 changes: 72 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2123,3 +2123,75 @@ def test_tree(self, create_test_datatree):
expected = create_test_datatree(modify=lambda ds: np.sin(ds))
result_tree = np.sin(dt)
assert_equal(result_tree, expected)


class Closer:
def __init__(self):
self.closed = False

def close(self):
if self.closed:
raise RuntimeError("already closed")
self.closed = True


@pytest.fixture()
def tree_and_closers():
tree = DataTree.from_dict({"/child/grandchild": None})
closers = {
"/": Closer(),
"/child": Closer(),
"/child/grandchild": Closer(),
}
for path, closer in closers.items():
tree[path].set_close(closer.close)
return tree, closers


class TestClose:
def test_close(self, tree_and_closers):
tree, closers = tree_and_closers
assert not any(closer.closed for closer in closers.values())
tree.close()
assert all(closer.closed for closer in closers.values())
tree.close() # should not error

def test_context_manager(self, tree_and_closers):
tree, closers = tree_and_closers
assert not any(closer.closed for closer in closers.values())
with tree:
pass
assert all(closer.closed for closer in closers.values())

def test_close_child(self, tree_and_closers):
tree, closers = tree_and_closers
assert not any(closer.closed for closer in closers.values())
tree["child"].close() # should only close descendants
assert not closers["/"].closed
assert closers["/child"].closed
assert closers["/child/grandchild"].closed

def test_close_datasetview(self, tree_and_closers):
tree, _ = tree_and_closers

with pytest.raises(
AttributeError,
match=re.escape(
r"cannot close a DatasetView(). Close the associated DataTree node instead"
),
):
tree.dataset.close()

with pytest.raises(
AttributeError, match=re.escape(r"cannot modify a DatasetView()")
):
tree.dataset.set_close(None)

def test_close_dataset(self, tree_and_closers):
tree, closers = tree_and_closers
ds = tree.to_dataset() # should discard closers
ds.close()
assert not closers["/"].closed

# with tree:
# pass

0 comments on commit 994550f

Please sign in to comment.