diff --git a/src/anemoi/utils/caching.py b/src/anemoi/utils/caching.py index c628b44..677eced 100644 --- a/src/anemoi/utils/caching.py +++ b/src/anemoi/utils/caching.py @@ -14,11 +14,47 @@ import time from threading import Lock +import numpy as np + LOCK = Lock() CACHE = {} -def cache(key, proc, collection="default", expires=None): +def _json_save(path, data): + with open(path, "w") as f: + json.dump(data, f) + + +def _json_load(path): + with open(path, "r") as f: + return json.load(f) + + +def _npz_save(path, data): + return np.savez(path, **data) + + +def _npz_load(path): + return np.load(path, allow_pickle=True) + + +def _get_cache_path(collection): + return os.path.join(os.path.expanduser("~"), ".cache", "anemoi", collection) + + +def clean_cache(collection="default"): + path = _get_cache_path(collection) + if not os.path.exists(path): + return + for filename in os.listdir(path): + os.remove(os.path.join(path, filename)) + + +def cache(key, proc, collection="default", expires=None, encoding="json"): + load, save, ext = dict( + json=(_json_load, _json_save, ""), + npz=(_npz_load, _npz_save, ".npz"), + )[encoding] key = json.dumps(key, sort_keys=True) m = hashlib.md5() @@ -28,24 +64,22 @@ def cache(key, proc, collection="default", expires=None): if m in CACHE: return CACHE[m] - path = os.path.join(os.path.expanduser("~"), ".cache", "anemoi", collection) + path = _get_cache_path(collection) os.makedirs(path, exist_ok=True) - filename = os.path.join(path, m) + filename = os.path.join(path, m) + ext if os.path.exists(filename): - with open(filename, "r") as f: - data = json.load(f) - if expires is None or data["expires"] > time.time(): - if data["key"] == key: - return data["value"] + data = load(filename) + if expires is None or data["expires"] > time.time(): + if data["key"] == key: + return data["value"] value = proc() data = {"key": key, "value": value} if expires is not None: data["expires"] = time.time() + expires - with open(filename, "w") as f: - json.dump(data, f) + save(filename, data) CACHE[m] = value return value @@ -54,9 +88,10 @@ def cache(key, proc, collection="default", expires=None): class cached: """Decorator to cache the result of a function.""" - def __init__(self, collection="default", expires=None): + def __init__(self, collection="default", expires=None, encoding="json"): self.collection = collection self.expires = expires + self.encoding = encoding def __call__(self, func): @@ -69,6 +104,7 @@ def wrapped(*args, **kwargs): lambda: func(*args, **kwargs), self.collection, self.expires, + self.encoding, ) return wrapped diff --git a/src/anemoi/utils/grids.py b/src/anemoi/utils/grids.py new file mode 100644 index 0000000..1ab1c74 --- /dev/null +++ b/src/anemoi/utils/grids.py @@ -0,0 +1,42 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +"""Utilities for working with grids. + +""" + +import logging +from io import BytesIO + +import numpy as np +import requests + +from .caching import cached + +LOG = logging.getLogger(__name__) + + +GRIDS_URL_PATTERN = "https://get.ecmwf.int/repository/anemoi/grids/grid-{name}.npz" + + +@cached(collection="grids", encoding="npz") +def _grids(name): + url = GRIDS_URL_PATTERN.format(name=name.lower()) + LOG.error("Downloading grids from %s", url) + LOG.warning("Downloading grids from %s", url) + response = requests.get(url) + response.raise_for_status() + return response.content + + +def grids(name): + data = _grids(name) + npz = np.load(BytesIO(data)) + return dict(npz) diff --git a/tests/test_caching.py b/tests/test_caching.py new file mode 100644 index 0000000..83599ba --- /dev/null +++ b/tests/test_caching.py @@ -0,0 +1,113 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import numpy as np + +from anemoi.utils.caching import cached +from anemoi.utils.caching import clean_cache + + +def check(f, data): + """Check that the function f returns the expected values from the data. + The function f is called three times for each value in the data. + The number of actual calls to the function is checked to make sure the cache is used when it should be. + """ + + for i, x in enumerate(data): + assert data.n == i + + res = f(x) + assert type(res) == type(data[x]) # noqa: E721 + assert str(res) == str(data[x]) + assert data.n == i + 1 + + res = f(x) + assert type(res) == type(data[x]) # noqa: E721 + assert str(res) == str(data[x]) + assert data.n == i + 1 + + +class Data(dict): + """Simple class to store data and count the number of calls to the function.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.n = 0 + + +global DATA + +######################################### +# basic test +######################################### +global values_a +values_a = Data(a=1, b=2) + + +@cached(collection="test", expires=0) +def func_a(x): + global values_a + values_a.n += 1 + return values_a[x] + + +def test_cached_basic(*values, **kwargs): + clean_cache("test") + check(func_a, values_a) + + +######################################### +# Test with numpy arrays +######################################### + +global values_c +values_c = Data( + a=dict(A=np.array([1, 2, 3]), B=np.array([4, 5, 6])), + b=dict(A=np.array([7, 8, 9]), B=np.array([10, 11, 12])), +) + + +@cached(collection="test", expires=0, encoding="npz") +def func_c(x): + global values_c + values_c.n += 1 + return values_c[x] + + +def test_cached_npz(*values, **kwargs): + clean_cache("test") + check(func_c, values_c) + + +######################################### +# Test with a various types +global values_d +values_d = Data(a="4", b=5.0, c=dict(d=6), e=[7, 8, 9], f=(10, 11, 12)) + + +@cached(collection="test", expires=0) +def func_d(x): + global values_d + values_d.n += 1 + return values_d[x] + + +def test_cached_various_types(*values, **kwargs): + clean_cache("test") + check(func_d, values_d) + + +######################################### + +if __name__ == "__main__": + for name, obj in list(globals().items()): + if name.startswith("test_") and callable(obj): + print(f"Running {name}...") + obj() diff --git a/tests/test_grids.py b/tests/test_grids.py new file mode 100644 index 0000000..e378c2a --- /dev/null +++ b/tests/test_grids.py @@ -0,0 +1,28 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from anemoi.utils.grids import grids + + +def test_o96(): + x = grids("o96") + assert x["latitudes"].mean() == 0.0 + assert x["longitudes"].mean() == 179.14285714285714 + assert x["latitudes"].shape == (40320,) + assert x["longitudes"].shape == (40320,) + assert x["latitudes"][31415] == -31.324557701757268 + assert x["longitudes"][31415] == 224.32835820895522 + + +if __name__ == "__main__": + for name, obj in list(globals().items()): + if name.startswith("test_") and callable(obj): + print(f"Running {name}...") + obj()