Skip to content

Commit

Permalink
feat: added-anemoi-utils-grids-and-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Jan 14, 2025
1 parent cae49ae commit a736cec
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 11 deletions.
58 changes: 47 additions & 11 deletions src/anemoi/utils/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,47 @@
import time
from threading import Lock

import numpy as np

LOCK = Lock()
CACHE = {}


def cache(key, proc, collection="default", expires=None):
def _json_save(path, data):
with open(path, "w") as f:
json.dump(data, f)


def _json_load(path):
with open(path, "r") as f:
return json.load(f)


def _npz_save(path, data):
return np.savez(path, **data)


def _npz_load(path):
return np.load(path, allow_pickle=True)


def _get_cache_path(collection):
return os.path.join(os.path.expanduser("~"), ".cache", "anemoi", collection)


def clean_cache(collection="default"):
path = _get_cache_path(collection)
if not os.path.exists(path):
return
for filename in os.listdir(path):
os.remove(os.path.join(path, filename))


def cache(key, proc, collection="default", expires=None, encoding="json"):
load, save, ext = dict(
json=(_json_load, _json_save, ""),
npz=(_npz_load, _npz_save, ".npz"),
)[encoding]

key = json.dumps(key, sort_keys=True)
m = hashlib.md5()
Expand All @@ -28,24 +64,22 @@ def cache(key, proc, collection="default", expires=None):
if m in CACHE:
return CACHE[m]

path = os.path.join(os.path.expanduser("~"), ".cache", "anemoi", collection)
path = _get_cache_path(collection)
os.makedirs(path, exist_ok=True)

filename = os.path.join(path, m)
filename = os.path.join(path, m) + ext
if os.path.exists(filename):
with open(filename, "r") as f:
data = json.load(f)
if expires is None or data["expires"] > time.time():
if data["key"] == key:
return data["value"]
data = load(filename)
if expires is None or data["expires"] > time.time():
if data["key"] == key:
return data["value"]

value = proc()
data = {"key": key, "value": value}
if expires is not None:
data["expires"] = time.time() + expires

with open(filename, "w") as f:
json.dump(data, f)
save(filename, data)

CACHE[m] = value
return value
Expand All @@ -54,9 +88,10 @@ def cache(key, proc, collection="default", expires=None):
class cached:
"""Decorator to cache the result of a function."""

def __init__(self, collection="default", expires=None):
def __init__(self, collection="default", expires=None, encoding="json"):
self.collection = collection
self.expires = expires
self.encoding = encoding

def __call__(self, func):

Expand All @@ -69,6 +104,7 @@ def wrapped(*args, **kwargs):
lambda: func(*args, **kwargs),
self.collection,
self.expires,
self.encoding,
)

return wrapped
42 changes: 42 additions & 0 deletions src/anemoi/utils/grids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# (C) Copyright 2025 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


"""Utilities for working with grids.
"""

import logging
from io import BytesIO

import numpy as np
import requests

from .caching import cached

LOG = logging.getLogger(__name__)


GRIDS_URL_PATTERN = "https://get.ecmwf.int/repository/anemoi/grids/grid-{name}.npz"


@cached(collection="grids", encoding="npz")
def _grids(name):
url = GRIDS_URL_PATTERN.format(name=name.lower())
LOG.error("Downloading grids from %s", url)
LOG.warning("Downloading grids from %s", url)
response = requests.get(url)
response.raise_for_status()
return response.content


def grids(name):
data = _grids(name)
npz = np.load(BytesIO(data))
return dict(npz)
113 changes: 113 additions & 0 deletions tests/test_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# (C) Copyright 2025 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import numpy as np

from anemoi.utils.caching import cached
from anemoi.utils.caching import clean_cache


def check(f, data):
"""Check that the function f returns the expected values from the data.
The function f is called three times for each value in the data.
The number of actual calls to the function is checked to make sure the cache is used when it should be.
"""

for i, x in enumerate(data):
assert data.n == i

res = f(x)
assert type(res) == type(data[x]) # noqa: E721
assert str(res) == str(data[x])
assert data.n == i + 1

res = f(x)
assert type(res) == type(data[x]) # noqa: E721
assert str(res) == str(data[x])
assert data.n == i + 1


class Data(dict):
"""Simple class to store data and count the number of calls to the function."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.n = 0


global DATA

#########################################
# basic test
#########################################
global values_a
values_a = Data(a=1, b=2)


@cached(collection="test", expires=0)
def func_a(x):
global values_a
values_a.n += 1
return values_a[x]


def test_cached_basic(*values, **kwargs):
clean_cache("test")
check(func_a, values_a)


#########################################
# Test with numpy arrays
#########################################

global values_c
values_c = Data(
a=dict(A=np.array([1, 2, 3]), B=np.array([4, 5, 6])),
b=dict(A=np.array([7, 8, 9]), B=np.array([10, 11, 12])),
)


@cached(collection="test", expires=0, encoding="npz")
def func_c(x):
global values_c
values_c.n += 1
return values_c[x]


def test_cached_npz(*values, **kwargs):
clean_cache("test")
check(func_c, values_c)


#########################################
# Test with a various types
global values_d
values_d = Data(a="4", b=5.0, c=dict(d=6), e=[7, 8, 9], f=(10, 11, 12))


@cached(collection="test", expires=0)
def func_d(x):
global values_d
values_d.n += 1
return values_d[x]


def test_cached_various_types(*values, **kwargs):
clean_cache("test")
check(func_d, values_d)


#########################################

if __name__ == "__main__":
for name, obj in list(globals().items()):
if name.startswith("test_") and callable(obj):
print(f"Running {name}...")
obj()
28 changes: 28 additions & 0 deletions tests/test_grids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# (C) Copyright 2025 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


from anemoi.utils.grids import grids


def test_o96():
x = grids("o96")
assert x["latitudes"].mean() == 0.0
assert x["longitudes"].mean() == 179.14285714285714
assert x["latitudes"].shape == (40320,)
assert x["longitudes"].shape == (40320,)
assert x["latitudes"][31415] == -31.324557701757268
assert x["longitudes"][31415] == 224.32835820895522


if __name__ == "__main__":
for name, obj in list(globals().items()):
if name.startswith("test_") and callable(obj):
print(f"Running {name}...")
obj()

0 comments on commit a736cec

Please sign in to comment.