-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: added-anemoi-utils-grids-and-tests
- Loading branch information
1 parent
cae49ae
commit a736cec
Showing
4 changed files
with
230 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# (C) Copyright 2025 Anemoi contributors. | ||
# | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
|
||
"""Utilities for working with grids. | ||
""" | ||
|
||
import logging | ||
from io import BytesIO | ||
|
||
import numpy as np | ||
import requests | ||
|
||
from .caching import cached | ||
|
||
LOG = logging.getLogger(__name__) | ||
|
||
|
||
GRIDS_URL_PATTERN = "https://get.ecmwf.int/repository/anemoi/grids/grid-{name}.npz" | ||
|
||
|
||
@cached(collection="grids", encoding="npz") | ||
def _grids(name): | ||
url = GRIDS_URL_PATTERN.format(name=name.lower()) | ||
LOG.error("Downloading grids from %s", url) | ||
LOG.warning("Downloading grids from %s", url) | ||
response = requests.get(url) | ||
response.raise_for_status() | ||
return response.content | ||
|
||
|
||
def grids(name): | ||
data = _grids(name) | ||
npz = np.load(BytesIO(data)) | ||
return dict(npz) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# (C) Copyright 2025 Anemoi contributors. | ||
# | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
|
||
import numpy as np | ||
|
||
from anemoi.utils.caching import cached | ||
from anemoi.utils.caching import clean_cache | ||
|
||
|
||
def check(f, data): | ||
"""Check that the function f returns the expected values from the data. | ||
The function f is called three times for each value in the data. | ||
The number of actual calls to the function is checked to make sure the cache is used when it should be. | ||
""" | ||
|
||
for i, x in enumerate(data): | ||
assert data.n == i | ||
|
||
res = f(x) | ||
assert type(res) == type(data[x]) # noqa: E721 | ||
assert str(res) == str(data[x]) | ||
assert data.n == i + 1 | ||
|
||
res = f(x) | ||
assert type(res) == type(data[x]) # noqa: E721 | ||
assert str(res) == str(data[x]) | ||
assert data.n == i + 1 | ||
|
||
|
||
class Data(dict): | ||
"""Simple class to store data and count the number of calls to the function.""" | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.n = 0 | ||
|
||
|
||
global DATA | ||
|
||
######################################### | ||
# basic test | ||
######################################### | ||
global values_a | ||
values_a = Data(a=1, b=2) | ||
|
||
|
||
@cached(collection="test", expires=0) | ||
def func_a(x): | ||
global values_a | ||
values_a.n += 1 | ||
return values_a[x] | ||
|
||
|
||
def test_cached_basic(*values, **kwargs): | ||
clean_cache("test") | ||
check(func_a, values_a) | ||
|
||
|
||
######################################### | ||
# Test with numpy arrays | ||
######################################### | ||
|
||
global values_c | ||
values_c = Data( | ||
a=dict(A=np.array([1, 2, 3]), B=np.array([4, 5, 6])), | ||
b=dict(A=np.array([7, 8, 9]), B=np.array([10, 11, 12])), | ||
) | ||
|
||
|
||
@cached(collection="test", expires=0, encoding="npz") | ||
def func_c(x): | ||
global values_c | ||
values_c.n += 1 | ||
return values_c[x] | ||
|
||
|
||
def test_cached_npz(*values, **kwargs): | ||
clean_cache("test") | ||
check(func_c, values_c) | ||
|
||
|
||
######################################### | ||
# Test with a various types | ||
global values_d | ||
values_d = Data(a="4", b=5.0, c=dict(d=6), e=[7, 8, 9], f=(10, 11, 12)) | ||
|
||
|
||
@cached(collection="test", expires=0) | ||
def func_d(x): | ||
global values_d | ||
values_d.n += 1 | ||
return values_d[x] | ||
|
||
|
||
def test_cached_various_types(*values, **kwargs): | ||
clean_cache("test") | ||
check(func_d, values_d) | ||
|
||
|
||
######################################### | ||
|
||
if __name__ == "__main__": | ||
for name, obj in list(globals().items()): | ||
if name.startswith("test_") and callable(obj): | ||
print(f"Running {name}...") | ||
obj() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# (C) Copyright 2025 Anemoi contributors. | ||
# | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
|
||
from anemoi.utils.grids import grids | ||
|
||
|
||
def test_o96(): | ||
x = grids("o96") | ||
assert x["latitudes"].mean() == 0.0 | ||
assert x["longitudes"].mean() == 179.14285714285714 | ||
assert x["latitudes"].shape == (40320,) | ||
assert x["longitudes"].shape == (40320,) | ||
assert x["latitudes"][31415] == -31.324557701757268 | ||
assert x["longitudes"][31415] == 224.32835820895522 | ||
|
||
|
||
if __name__ == "__main__": | ||
for name, obj in list(globals().items()): | ||
if name.startswith("test_") and callable(obj): | ||
print(f"Running {name}...") | ||
obj() |