Basic pluggable cache mechanism (#167)
maxrjones authored Jul 31, 2023
1 parent f614ea5 commit f9d7ac8
Showing 3 changed files with 101 additions and 3 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -45,6 +45,7 @@ dev = [
"pytest-cov",
"tensorflow",
"torch",
"zarr",
]
[project.urls]
documentation = "https://xbatcher.readthedocs.io/en/latest/"
49 changes: 46 additions & 3 deletions xbatcher/generators.py
@@ -3,7 +3,17 @@
import itertools
import warnings
from operator import itemgetter
-from typing import Any, Dict, Hashable, Iterator, List, Optional, Sequence, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Hashable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Union,
+)

import numpy as np
import xarray as xr
@@ -364,6 +374,12 @@ class BatchGenerator:
preload_batch : bool, optional
If ``True``, each batch will be loaded into memory before reshaping /
processing, triggering any dask arrays to be computed.
cache : dict, optional
Dict-like object to cache batches in (e.g., Zarr DirectoryStore). Note:
The caching API is experimental and subject to change.
cache_preprocess : callable, optional
A function to apply to batches prior to caching.
Note: The caching API is experimental and subject to change.

Yields
------
@@ -379,8 +395,13 @@ def __init__(
batch_dims: Dict[Hashable, int] = {},
concat_input_dims: bool = False,
preload_batch: bool = True,
cache: Optional[Dict[str, Any]] = None,
cache_preprocess: Optional[Callable] = None,
):
self.ds = ds
self.cache = cache
self.cache_preprocess = cache_preprocess

self._batch_selectors: BatchSchema = BatchSchema(
ds,
input_dims=input_dims,
@@ -426,6 +447,9 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
if idx < 0:
idx = list(self._batch_selectors.selectors)[idx]

if self.cache and self._batch_in_cache(idx):
return self._get_cached_batch(idx)

if idx in self._batch_selectors.selectors:
if self.concat_input_dims:
new_dim_suffix = "_input"
@@ -451,14 +475,33 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
)
dsc = xr.concat(all_dsets, dim="input_batch")
new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
-return _maybe_stack_batch_dims(dsc, new_input_dims)
+batch = _maybe_stack_batch_dims(dsc, new_input_dims)
else:
batch_ds = self.ds.isel(self._batch_selectors.selectors[idx][0])
if self.preload_batch:
batch_ds.load()
-return _maybe_stack_batch_dims(
+batch = _maybe_stack_batch_dims(
batch_ds,
list(self.input_dims),
)
else:
raise IndexError("list index out of range")

if self.cache is not None and self.cache_preprocess is not None:
batch = self.cache_preprocess(batch)
if self.cache is not None:
self._cache_batch(idx, batch)

return batch

def _batch_in_cache(self, idx: int) -> bool:
return self.cache is not None and f"{idx}/.zgroup" in self.cache

def _cache_batch(self, idx: int, batch: Union[xr.Dataset, xr.DataArray]) -> None:
batch.to_zarr(self.cache, group=str(idx), mode="a")

def _get_cached_batch(self, idx: int) -> xr.Dataset:
ds = xr.open_zarr(self.cache, group=str(idx))
if self.preload_batch:
ds = ds.load()
return ds
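
For context, a minimal usage sketch of the caching hooks added above (not part of this commit). It assumes zarr is installed (and dask, for the .chunk call); the dataset, store, and preprocess function are illustrative:

    import numpy as np
    import xarray as xr

    from xbatcher import BatchGenerator

    ds = xr.Dataset({"foo": ("x", np.arange(100))})
    cache = {}  # any dict-like store works, e.g. a Zarr DirectoryStore

    def preprocess(batch):
        # runs once per batch, before the batch is written to the cache
        return batch.load().chunk(-1)

    bg = BatchGenerator(
        ds, input_dims={"x": 10}, cache=cache, cache_preprocess=preprocess
    )
    first = bg[0]  # computed, preprocessed, then written via batch.to_zarr
    again = bg[0]  # "0/.zgroup" is now in the store, so this reads back through xr.open_zarr

Repeated accesses hit _get_cached_batch, which opens the Zarr group named after the batch index and, when preload_batch is True, loads it into memory.
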
54 changes: 54 additions & 0 deletions xbatcher/tests/test_generators.py
@@ -1,3 +1,5 @@
from typing import Any, Dict

import numpy as np
import pytest
import xarray as xr
@@ -356,3 +358,55 @@ def test_input_overlap_exceptions(sample_ds_1d):
with pytest.raises(ValueError) as e:
BatchGenerator(sample_ds_1d, input_dims={"x": 10}, input_overlap={"x": 20})
assert len(e) == 1


@pytest.mark.parametrize("preload", [True, False])
def test_batcher_cached_getitem(sample_ds_1d, preload) -> None:
pytest.importorskip("zarr")
cache: Dict[str, Any] = {}

def preproc(ds):
processed = ds.load().chunk(-1)
processed.attrs["foo"] = "bar"
return processed

bg = BatchGenerator(
sample_ds_1d,
input_dims={"x": 10},
cache=cache,
cache_preprocess=preproc,
preload_batch=preload,
)

# first batch
assert bg[0].sizes["x"] == 10
ds_no_cache = bg[1]
# last batch
assert bg[-1].sizes["x"] == 10

assert "0/.zgroup" in cache

# now from cache
# first batch
assert bg[0].sizes["x"] == 10
# last batch
assert bg[-1].sizes["x"] == 10
ds_cache = bg[1]

assert ds_no_cache.attrs["foo"] == "bar"
assert ds_cache.attrs["foo"] == "bar"

xr.testing.assert_equal(ds_no_cache, ds_cache)
xr.testing.assert_identical(ds_no_cache, ds_cache)

# without preprocess func
bg = BatchGenerator(
sample_ds_1d, input_dims={"x": 10}, cache=cache, preload_batch=preload
)
assert bg.cache_preprocess is None
assert bg[0].sizes["x"] == 10
ds_no_cache = bg[1]
assert "1/.zgroup" in cache
ds_cache = bg[1]
xr.testing.assert_equal(ds_no_cache, ds_cache)
xr.testing.assert_identical(ds_no_cache, ds_cache)
