Skip to content

Commit

Permalink
Merge pull request #49 from ecmwf/bugfix/hindcasts
Browse files Browse the repository at this point in the history
fix hindcasts
  • Loading branch information
floriankrb authored Nov 14, 2024
2 parents 03c1583 + 75f38ce commit 6c7bb41
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 43 deletions.
18 changes: 9 additions & 9 deletions src/anemoi/utils/hindcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ def __init__(self, reference_dates, years=20):

self.reference_dates = reference_dates

if isinstance(years, list):
self.years = years
else:
self.years = range(1, years + 1)
assert isinstance(years, int), f"years must be an integer, got {years}"
assert years > 0, f"years must be greater than 0, got {years}"
self.years = years

def __iter__(self):
for reference_date in self.reference_dates:
for year in self.years:
if reference_date.month == 2 and reference_date.day == 29:
date = datetime.datetime(reference_date.year - year, 2, 28)
else:
date = datetime.datetime(reference_date.year - year, reference_date.month, reference_date.day)
year, month, day = reference_date.year, reference_date.month, reference_date.day
if (month, day) == (2, 29):
day = 28

for i in range(1, self.years + 1):
date = datetime.datetime(year - i, month, day)
yield (date, reference_date)
25 changes: 21 additions & 4 deletions src/anemoi/utils/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def __call__(self, factory):
return factory


_BY_KIND = {}


class Registry:
"""A registry of factories"""

Expand All @@ -39,6 +42,11 @@ def __init__(self, package, key="_type"):
self.registered = {}
self.kind = package.split(".")[-1]
self.key = key
_BY_KIND[self.kind] = self

@classmethod
def lookup_kind(cls, kind: str):
return _BY_KIND.get(kind)

def register(self, name: str, factory: callable = None):

Expand All @@ -47,14 +55,19 @@ def register(self, name: str, factory: callable = None):

self.registered[name] = factory

# def registered(self, name: str):
# return name in self.registered

def _load(self, file):
name, _ = os.path.splitext(file)
try:
importlib.import_module(f".{name}", package=self.package)
except Exception:
LOG.warning(f"Error loading filter '{self.package}.{name}'", exc_info=True)

def lookup(self, name: str) -> callable:
def lookup(self, name: str, *, return_none=False) -> callable:

# print('✅✅✅✅✅✅✅✅✅✅✅✅✅', name, self.registered)
if name in self.registered:
return self.registered[name]

Expand Down Expand Up @@ -87,8 +100,12 @@ def lookup(self, name: str) -> callable:
self.registered[name] = entry_point.load()

if name not in self.registered:
if return_none:
return None

for e in self.registered:
LOG.info(f"Registered: {e}")

raise ValueError(f"Cannot load '{name}' from {self.package}")

return self.registered[name]
Expand All @@ -97,8 +114,8 @@ def create(self, name: str, *args, **kwargs):
factory = self.lookup(name)
return factory(*args, **kwargs)

def __call__(self, name: str, *args, **kwargs):
return self.create(name, *args, **kwargs)
# def __call__(self, name: str, *args, **kwargs):
# return self.create(name, *args, **kwargs)

def from_config(self, config, *args, **kwargs):
if isinstance(config, str):
Expand All @@ -125,5 +142,5 @@ def from_config(self, config, *args, **kwargs):
return self.create(key, *args, value, **kwargs)

raise ValueError(
f"Entry '{config}' must either be a string, a dictionray with a single entry, or a dictionary with a '{self.key}' key"
f"Entry '{config}' must either be a string, a dictionary with a single entry, or a dictionary with a '{self.key}' key"
)
30 changes: 0 additions & 30 deletions tests/test_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,36 +83,6 @@ def test_date_hindcast_1():
assert len(list(d)) == 60


def test_date_hindcast_2():
d = _(
"""
- name: hindcast
reference_dates:
start: 2023-01-01
end: 2023-01-03
frequency: 24
years: [2018, 2019, 2020, 2021]
"""
)
assert len(list(d)) == 12


def test_date_hindcast_3():
d = _(
"""
- name: hindcast
reference_dates:
start: 2022-12-25 00:00:00
end: 2022-12-31 12:00:00
frequency: 12h
day_of_week: tuesday
years: [2018, 2019, 2020, 2021]
"""
)
print(list(d))
assert len(list(d)) == 8


if __name__ == "__main__":
for name, obj in list(globals().items()):
if name.startswith("test_") and callable(obj):
Expand Down

0 comments on commit 6c7bb41

Please sign in to comment.