From 4c2329a251dbcf48305fb6aa55c2f84a8f5a3394 Mon Sep 17 00:00:00 2001 From: b8raoult <53792887+b8raoult@users.noreply.github.com> Date: Fri, 1 Nov 2024 11:21:41 +0000 Subject: [PATCH] Feature/registry (#35) * add supporting_arrays to checkpoints --------- Co-authored-by: Florian Pinault --- CHANGELOG.md | 3 + docs/conf.py | 4 +- src/anemoi/utils/__init__.py | 4 +- src/anemoi/utils/checkpoints.py | 83 +++++++++++++++++++++++++--- src/anemoi/utils/config.py | 5 +- src/anemoi/utils/registry.py | 98 +++++++++++++++++++++++++++++++++ 6 files changed, 183 insertions(+), 14 deletions(-) create mode 100644 src/anemoi/utils/registry.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 72738a8..dd13b14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,8 +11,11 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-utils/compare/0.4.1...HEAD) ### Added +- Add supporting_arrays to checkpoints +- Add factories registry - Optional renaming of subcommands via `command` attribute [#34](https://github.com/ecmwf/anemoi-utils/pull/34) + ## [0.4.1](https://github.com/ecmwf/anemoi-utils/compare/0.4.0...0.4.1) - 2024-10-23 ## Fixed diff --git a/docs/conf.py b/docs/conf.py index 7760336..5d812ac 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,7 +29,7 @@ project = "Anemoi Utils" -author = "ECMWF" +author = "Anemoi contributors" year = datetime.datetime.now().year if year == 2024: @@ -37,7 +37,7 @@ else: years = "2024-%s" % (year,) -copyright = "%s, ECMWF" % (years,) +copyright = "%s, Anemoi contributors" % (years,) try: from anemoi.utils._version import __version__ diff --git a/src/anemoi/utils/__init__.py b/src/anemoi/utils/__init__.py index 9733be2..7b9efcd 100644 --- a/src/anemoi/utils/__init__.py +++ b/src/anemoi/utils/__init__.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. diff --git a/src/anemoi/utils/checkpoints.py b/src/anemoi/utils/checkpoints.py index 085e4d1..83d386d 100644 --- a/src/anemoi/utils/checkpoints.py +++ b/src/anemoi/utils/checkpoints.py @@ -27,7 +27,7 @@ DEFAULT_FOLDER = "anemoi-metadata" -def has_metadata(path: str, name: str = DEFAULT_NAME) -> bool: +def has_metadata(path: str, *, name: str = DEFAULT_NAME) -> bool: """Check if a checkpoint file has a metadata file Parameters @@ -49,13 +49,26 @@ def has_metadata(path: str, name: str = DEFAULT_NAME) -> bool: return False -def load_metadata(path: str, name: str = DEFAULT_NAME) -> dict: +def metadata_root(path: str, *, name: str = DEFAULT_NAME) -> bool: + + with zipfile.ZipFile(path, "r") as f: + for b in f.namelist(): + if os.path.basename(b) == name: + return os.path.dirname(b) + raise ValueError(f"Could not find '{name}' in {path}.") + + +def load_metadata(path: str, *, supporting_arrays=False, name: str = DEFAULT_NAME) -> dict: """Load metadata from a checkpoint file Parameters ---------- path : str The path to the checkpoint file + + supporting_arrays: bool, optional + If True, the function will return a dictionary with the supporting arrays + name : str, optional The name of the metadata file in the zip archive @@ -79,12 +92,29 @@ def load_metadata(path: str, name: str = DEFAULT_NAME) -> dict: if metadata is not None: with zipfile.ZipFile(path, "r") as f: - return json.load(f.open(metadata, "r")) + metadata = json.load(f.open(metadata, "r")) + if supporting_arrays: + metadata["supporting_arrays"] = load_supporting_arrays(f, metadata.get("supporting_arrays", {})) + return metadata, supporting_arrays + + return metadata else: raise ValueError(f"Could not find '{name}' in {path}.") -def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> None: +def load_supporting_arrays(zipf, entries) -> dict: + import numpy as np + + supporting_arrays = {} + for key, entry in entries.items(): + supporting_arrays[key] = np.frombuffer( + zipf.read(entry["path"]), + dtype=entry["dtype"], + ).reshape(entry["shape"]) + return supporting_arrays + + +def save_metadata(path, metadata, *, supporting_arrays=None, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> None: """Save metadata to a checkpoint file Parameters @@ -93,6 +123,8 @@ def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> N The path to the checkpoint file metadata : JSON A JSON serializable object + supporting_arrays: dict, optional + A dictionary of supporting NumPy arrays name : str, optional The name of the metadata file in the zip archive folder : str, optional @@ -118,19 +150,41 @@ def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> N directory = list(directories)[0] + LOG.info("Adding extra information to checkpoint %s", path) LOG.info("Saving metadata to %s/%s/%s", directory, folder, name) + metadata = metadata.copy() + if supporting_arrays is not None: + metadata["supporting_arrays_paths"] = { + key: dict(path=f"{directory}/{folder}/{key}.numpy", shape=value.shape, dtype=str(value.dtype)) + for key, value in supporting_arrays.items() + } + else: + metadata["supporting_arrays_paths"] = {} + zipf.writestr( f"{directory}/{folder}/{name}", json.dumps(metadata), ) + for name, entry in metadata["supporting_arrays_paths"].items(): + value = supporting_arrays[name] + LOG.info( + "Saving supporting array `%s` to %s (shape=%s, dtype=%s)", + name, + entry["path"], + entry["shape"], + entry["dtype"], + ) + zipf.writestr(entry["path"], value.tobytes()) + -def _edit_metadata(path, name, callback): +def _edit_metadata(path, name, callback, supporting_arrays=None): new_path = f"{path}.anemoi-edit-{time.time()}-{os.getpid()}.tmp" found = False + directory = None with TemporaryDirectory() as temp_dir: zipfile.ZipFile(path, "r").extractall(temp_dir) total = 0 @@ -141,10 +195,21 @@ def _edit_metadata(path, name, callback): if f == name: found = True callback(full) + directory = os.path.dirname(full) if not found: raise ValueError(f"Could not find '{name}' in {path}") + if supporting_arrays is not None: + + for key, entry in supporting_arrays.items(): + value = entry.tobytes() + fname = os.path.join(directory, f"{key}.numpy") + os.makedirs(os.path.dirname(fname), exist_ok=True) + with open(fname, "wb") as f: + f.write(value) + total += 1 + with zipfile.ZipFile(new_path, "w", zipfile.ZIP_DEFLATED) as zipf: with tqdm.tqdm(total=total, desc="Rebuilding checkpoint") as pbar: for root, dirs, files in os.walk(temp_dir): @@ -158,7 +223,7 @@ def _edit_metadata(path, name, callback): LOG.info("Updated metadata in %s", path) -def replace_metadata(path, metadata, name=DEFAULT_NAME): +def replace_metadata(path, metadata, supporting_arrays=None, *, name=DEFAULT_NAME): if not isinstance(metadata, dict): raise ValueError(f"metadata must be a dict, got {type(metadata)}") @@ -170,14 +235,14 @@ def callback(full): with open(full, "w") as f: json.dump(metadata, f) - _edit_metadata(path, name, callback) + return _edit_metadata(path, name, callback, supporting_arrays) -def remove_metadata(path, name=DEFAULT_NAME): +def remove_metadata(path, *, name=DEFAULT_NAME): LOG.info("Removing metadata '%s' from %s", name, path) def callback(full): os.remove(full) - _edit_metadata(path, name, callback) + return _edit_metadata(path, name, callback) diff --git a/src/anemoi/utils/config.py b/src/anemoi/utils/config.py index a6a9cb9..3a9406a 100644 --- a/src/anemoi/utils/config.py +++ b/src/anemoi/utils/config.py @@ -358,7 +358,7 @@ def check_config_mode(name="settings.toml", secrets_name=None, secrets=None) -> CHECKED[name] = True -def find(metadata, what, result=None): +def find(metadata, what, result=None, *, select: callable = None): if result is None: result = [] @@ -369,7 +369,8 @@ def find(metadata, what, result=None): if isinstance(metadata, dict): if what in metadata: - result.append(metadata[what]) + if select is None or select(metadata[what]): + result.append(metadata[what]) for k, v in metadata.items(): find(v, what, result) diff --git a/src/anemoi/utils/registry.py b/src/anemoi/utils/registry.py new file mode 100644 index 0000000..9d4bcce --- /dev/null +++ b/src/anemoi/utils/registry.py @@ -0,0 +1,98 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import importlib +import logging +import os +import sys + +import entrypoints + +LOG = logging.getLogger(__name__) + + +class Wrapper: + """A wrapper for the registry""" + + def __init__(self, name, registry): + self.name = name + self.registry = registry + + def __call__(self, factory): + self.registry.register(self.name, factory) + return factory + + +class Registry: + """A registry of factories""" + + def __init__(self, package): + + self.package = package + self.registered = {} + self.kind = package.split(".")[-1] + + def register(self, name: str, factory: callable = None): + + if factory is None: + return Wrapper(name, self) + + self.registered[name] = factory + + def _load(self, file): + name, _ = os.path.splitext(file) + try: + importlib.import_module(f".{name}", package=self.package) + except Exception: + LOG.warning(f"Error loading filter '{self.package}.{name}'", exc_info=True) + + def lookup(self, name: str) -> callable: + if name in self.registered: + return self.registered[name] + + directory = sys.modules[self.package].__path__[0] + + for file in os.listdir(directory): + + if file[0] == ".": + continue + + if file == "__init__.py": + continue + + full = os.path.join(directory, file) + if os.path.isdir(full): + if os.path.exists(os.path.join(full, "__init__.py")): + self._load(file) + continue + + if file.endswith(".py"): + self._load(file) + + entrypoint_group = f"anemoi.{self.kind}" + for entry_point in entrypoints.get_group_all(entrypoint_group): + if entry_point.name == name: + if name in self.registered: + LOG.warning( + f"Overwriting builtin '{name}' from {self.package} with plugin '{entry_point.module_name}'" + ) + self.registered[name] = entry_point.load() + + if name not in self.registered: + raise ValueError(f"Cannot load '{name}' from {self.package}") + + return self.registered[name] + + def create(self, name: str, *args, **kwargs): + factory = self.lookup(name) + return factory(*args, **kwargs) + + def __call__(self, name: str, *args, **kwargs): + return self.create(name, *args, **kwargs)