From f9d7ac8c955aebb2e24a58c0fa166ad3393ccc5a Mon Sep 17 00:00:00 2001
From: Max Jones <14077947+maxrjones@users.noreply.github.com>
Date: Mon, 31 Jul 2023 19:24:04 -0400
Subject: [PATCH] Basic pluggable cache mechanism (#167)

---
 pyproject.toml                    |  1 +
 xbatcher/generators.py            | 49 ++++++++++++++++++++++++++++
 xbatcher/tests/test_generators.py | 54 +++++++++++++++++++++++++++++++
 3 files changed, 101 insertions(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 6b1079a..73483f6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ dev = [
     "pytest-cov",
     "tensorflow",
     "torch",
+    "zarr",
 ]
 [project.urls]
 documentation = "https://xbatcher.readthedocs.io/en/latest/"
diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index f1f4a78..b5edff0 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -3,7 +3,17 @@
 import itertools
 import warnings
 from operator import itemgetter
-from typing import Any, Dict, Hashable, Iterator, List, Optional, Sequence, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Hashable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Union,
+)
 
 import numpy as np
 import xarray as xr
@@ -364,6 +374,12 @@ class BatchGenerator:
     preload_batch : bool, optional
         If ``True``, each batch will be loaded into memory before reshaping /
         processing, triggering any dask arrays to be computed.
+    cache : dict, optional
+        Dict-like object to cache batches in (e.g., Zarr DirectoryStore). Note:
+        The caching API is experimental and subject to change.
+    cache_preprocess: callable, optional
+        A function to apply to batches prior to caching.
+        Note: The caching API is experimental and subject to change.
 
     Yields
     ------
@@ -379,8 +395,13 @@ def __init__(
         batch_dims: Dict[Hashable, int] = {},
         concat_input_dims: bool = False,
         preload_batch: bool = True,
+        cache: Optional[Dict[str, Any]] = None,
+        cache_preprocess: Optional[Callable] = None,
     ):
         self.ds = ds
+        self.cache = cache
+        self.cache_preprocess = cache_preprocess
+
         self._batch_selectors: BatchSchema = BatchSchema(
             ds,
             input_dims=input_dims,
@@ -426,6 +447,9 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
         if idx < 0:
             idx = list(self._batch_selectors.selectors)[idx]
 
+        if self.cache and self._batch_in_cache(idx):
+            return self._get_cached_batch(idx)
+
         if idx in self._batch_selectors.selectors:
             if self.concat_input_dims:
                 new_dim_suffix = "_input"
@@ -451,14 +475,33 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
                     )
                 dsc = xr.concat(all_dsets, dim="input_batch")
                 new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
-                return _maybe_stack_batch_dims(dsc, new_input_dims)
+                batch = _maybe_stack_batch_dims(dsc, new_input_dims)
             else:
                 batch_ds = self.ds.isel(self._batch_selectors.selectors[idx][0])
                 if self.preload_batch:
                     batch_ds.load()
-                return _maybe_stack_batch_dims(
+                batch = _maybe_stack_batch_dims(
                     batch_ds,
                     list(self.input_dims),
                 )
         else:
             raise IndexError("list index out of range")
+
+        if self.cache is not None and self.cache_preprocess is not None:
+            batch = self.cache_preprocess(batch)
+        if self.cache is not None:
+            self._cache_batch(idx, batch)
+
+        return batch
+
+    def _batch_in_cache(self, idx: int) -> bool:
+        return self.cache is not None and f"{idx}/.zgroup" in self.cache
+
+    def _cache_batch(self, idx: int, batch: Union[xr.Dataset, xr.DataArray]) -> None:
+        batch.to_zarr(self.cache, group=str(idx), mode="a")
+
+    def _get_cached_batch(self, idx: int) -> xr.Dataset:
+        ds = xr.open_zarr(self.cache, group=str(idx))
+        if self.preload_batch:
+            ds = ds.load()
+        return ds
diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py
index 3a9f98f..248dd03 100644
--- a/xbatcher/tests/test_generators.py
+++ b/xbatcher/tests/test_generators.py
@@ -1,3 +1,5 @@
+from typing import Any, Dict
+
 import numpy as np
 import pytest
 import xarray as xr
@@ -356,3 +358,55 @@ def test_input_overlap_exceptions(sample_ds_1d):
     with pytest.raises(ValueError) as e:
         BatchGenerator(sample_ds_1d, input_dims={"x": 10}, input_overlap={"x": 20})
     assert len(e) == 1
+
+
+@pytest.mark.parametrize("preload", [True, False])
+def test_batcher_cached_getitem(sample_ds_1d, preload) -> None:
+    pytest.importorskip("zarr")
+    cache: Dict[str, Any] = {}
+
+    def preproc(ds):
+        processed = ds.load().chunk(-1)
+        processed.attrs["foo"] = "bar"
+        return processed
+
+    bg = BatchGenerator(
+        sample_ds_1d,
+        input_dims={"x": 10},
+        cache=cache,
+        cache_preprocess=preproc,
+        preload_batch=preload,
+    )
+
+    # first batch
+    assert bg[0].sizes["x"] == 10
+    ds_no_cache = bg[1]
+    # last batch
+    assert bg[-1].sizes["x"] == 10
+
+    assert "0/.zgroup" in cache
+
+    # now from cache
+    # first batch
+    assert bg[0].sizes["x"] == 10
+    # last batch
+    assert bg[-1].sizes["x"] == 10
+    ds_cache = bg[1]
+
+    assert ds_no_cache.attrs["foo"] == "bar"
+    assert ds_cache.attrs["foo"] == "bar"
+
+    xr.testing.assert_equal(ds_no_cache, ds_cache)
+    xr.testing.assert_identical(ds_no_cache, ds_cache)
+
+    # without preprocess func
+    bg = BatchGenerator(
+        sample_ds_1d, input_dims={"x": 10}, cache=cache, preload_batch=preload
+    )
+    assert bg.cache_preprocess is None
+    assert bg[0].sizes["x"] == 10
+    ds_no_cache = bg[1]
+    assert "1/.zgroup" in cache
+    ds_cache = bg[1]
+    xr.testing.assert_equal(ds_no_cache, ds_cache)
+    xr.testing.assert_identical(ds_no_cache, ds_cache)
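
Usage sketch (editor's illustration, not part of the patch): the snippet below shows how the new cache and cache_preprocess arguments could be used with a plain dict as the Zarr store, mirroring the pattern exercised by test_batcher_cached_getitem above. The dataset ds, the add_attr helper, and the batch size are hypothetical names for this example only, and per the docstring added in this patch the caching API is experimental and subject to change.

from typing import Any, Dict

import numpy as np
import xarray as xr

from xbatcher import BatchGenerator

# Hypothetical 1-D dataset standing in for the sample_ds_1d test fixture.
ds = xr.Dataset({"foo": ("x", np.arange(100))})

# Any dict-like mapping can serve as the cache (e.g. a zarr.storage.DirectoryStore).
cache: Dict[str, Any] = {}


def add_attr(batch):
    # Optional hook applied to each batch just before it is written to the cache.
    batch.attrs["source"] = "xbatcher-cache"
    return batch


bg = BatchGenerator(ds, input_dims={"x": 10}, cache=cache, cache_preprocess=add_attr)

first = bg[0]        # computed, preprocessed, then written to the cache as group "0"
first_again = bg[0]  # later accesses of the same index are read back from the cache
xr.testing.assert_identical(first, first_again)

Note that cached batches are keyed only by integer index (group=str(idx)), so a cache mapping should not be shared between generators with different batch layouts.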