diff --git a/src/anemoi/utils/hindcasts.py b/src/anemoi/utils/hindcasts.py index 2c74efc..fff624f 100644 --- a/src/anemoi/utils/hindcasts.py +++ b/src/anemoi/utils/hindcasts.py @@ -27,16 +27,16 @@ def __init__(self, reference_dates, years=20): self.reference_dates = reference_dates - if isinstance(years, list): - self.years = years - else: - self.years = range(1, years + 1) + assert isinstance(years, int), f"years must be an integer, got {years}" + assert years > 0, f"years must be greater than 0, got {years}" + self.years = years def __iter__(self): for reference_date in self.reference_dates: - for year in self.years: - if reference_date.month == 2 and reference_date.day == 29: - date = datetime.datetime(reference_date.year - year, 2, 28) - else: - date = datetime.datetime(reference_date.year - year, reference_date.month, reference_date.day) + year, month, day = reference_date.year, reference_date.month, reference_date.day + if (month, day) == (2, 29): + day = 28 + + for i in range(1, self.years + 1): + date = datetime.datetime(year - i, month, day) yield (date, reference_date) diff --git a/src/anemoi/utils/registry.py b/src/anemoi/utils/registry.py index 03ee6ed..5ab88e0 100644 --- a/src/anemoi/utils/registry.py +++ b/src/anemoi/utils/registry.py @@ -30,6 +30,9 @@ def __call__(self, factory): return factory +_BY_KIND = {} + + class Registry: """A registry of factories""" @@ -39,6 +42,11 @@ def __init__(self, package, key="_type"): self.registered = {} self.kind = package.split(".")[-1] self.key = key + _BY_KIND[self.kind] = self + + @classmethod + def lookup_kind(cls, kind: str): + return _BY_KIND.get(kind) def register(self, name: str, factory: callable = None): @@ -47,6 +55,9 @@ def register(self, name: str, factory: callable = None): self.registered[name] = factory + # def registered(self, name: str): + # return name in self.registered + def _load(self, file): name, _ = os.path.splitext(file) try: @@ -54,7 +65,9 @@ def _load(self, file): except Exception: LOG.warning(f"Error loading filter '{self.package}.{name}'", exc_info=True) - def lookup(self, name: str) -> callable: + def lookup(self, name: str, *, return_none=False) -> callable: + + # print('✅✅✅✅✅✅✅✅✅✅✅✅✅', name, self.registered) if name in self.registered: return self.registered[name] @@ -87,8 +100,12 @@ def lookup(self, name: str) -> callable: self.registered[name] = entry_point.load() if name not in self.registered: + if return_none: + return None + for e in self.registered: LOG.info(f"Registered: {e}") + raise ValueError(f"Cannot load '{name}' from {self.package}") return self.registered[name] @@ -97,8 +114,8 @@ def create(self, name: str, *args, **kwargs): factory = self.lookup(name) return factory(*args, **kwargs) - def __call__(self, name: str, *args, **kwargs): - return self.create(name, *args, **kwargs) + # def __call__(self, name: str, *args, **kwargs): + # return self.create(name, *args, **kwargs) def from_config(self, config, *args, **kwargs): if isinstance(config, str): @@ -125,5 +142,5 @@ def from_config(self, config, *args, **kwargs): return self.create(key, *args, value, **kwargs) raise ValueError( - f"Entry '{config}' must either be a string, a dictionray with a single entry, or a dictionary with a '{self.key}' key" + f"Entry '{config}' must either be a string, a dictionary with a single entry, or a dictionary with a '{self.key}' key" ) diff --git a/tests/test_dates.py b/tests/test_dates.py index a285550..d136624 100644 --- a/tests/test_dates.py +++ b/tests/test_dates.py @@ -83,36 +83,6 @@ def test_date_hindcast_1(): assert len(list(d)) == 60 -def test_date_hindcast_2(): - d = _( - """ - - name: hindcast - reference_dates: - start: 2023-01-01 - end: 2023-01-03 - frequency: 24 - years: [2018, 2019, 2020, 2021] - """ - ) - assert len(list(d)) == 12 - - -def test_date_hindcast_3(): - d = _( - """ - - name: hindcast - reference_dates: - start: 2022-12-25 00:00:00 - end: 2022-12-31 12:00:00 - frequency: 12h - day_of_week: tuesday - years: [2018, 2019, 2020, 2021] - """ - ) - print(list(d)) - assert len(list(d)) == 8 - - if __name__ == "__main__": for name, obj in list(globals().items()): if name.startswith("test_") and callable(obj):