Skip to content

Commit

Permalink
Feature/registry (#35)
Browse files Browse the repository at this point in the history
* add supporting_arrays  to checkpoints

---------

Co-authored-by: Florian Pinault <[email protected]>
  • Loading branch information
b8raoult and floriankrb authored Nov 1, 2024
1 parent ba4279f commit 4c2329a
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 14 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ Keep it human-readable, your future self will thank you!
## [Unreleased](https://github.com/ecmwf/anemoi-utils/compare/0.4.1...HEAD)

### Added
- Add supporting_arrays to checkpoints
- Add factories registry
- Optional renaming of subcommands via `command` attribute [#34](https://github.com/ecmwf/anemoi-utils/pull/34)


## [0.4.1](https://github.com/ecmwf/anemoi-utils/compare/0.4.0...0.4.1) - 2024-10-23

## Fixed
Expand Down
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@

project = "Anemoi Utils"

author = "ECMWF"
author = "Anemoi contributors"

year = datetime.datetime.now().year
if year == 2024:
years = "2024"
else:
years = "2024-%s" % (year,)

copyright = "%s, ECMWF" % (years,)
copyright = "%s, Anemoi contributors" % (years,)

try:
from anemoi.utils._version import __version__
Expand Down
4 changes: 3 additions & 1 deletion src/anemoi/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
Expand Down
83 changes: 74 additions & 9 deletions src/anemoi/utils/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
DEFAULT_FOLDER = "anemoi-metadata"


def has_metadata(path: str, name: str = DEFAULT_NAME) -> bool:
def has_metadata(path: str, *, name: str = DEFAULT_NAME) -> bool:
"""Check if a checkpoint file has a metadata file
Parameters
Expand All @@ -49,13 +49,26 @@ def has_metadata(path: str, name: str = DEFAULT_NAME) -> bool:
return False


def load_metadata(path: str, name: str = DEFAULT_NAME) -> dict:
def metadata_root(path: str, *, name: str = DEFAULT_NAME) -> bool:

with zipfile.ZipFile(path, "r") as f:
for b in f.namelist():
if os.path.basename(b) == name:
return os.path.dirname(b)
raise ValueError(f"Could not find '{name}' in {path}.")


def load_metadata(path: str, *, supporting_arrays=False, name: str = DEFAULT_NAME) -> dict:
"""Load metadata from a checkpoint file
Parameters
----------
path : str
The path to the checkpoint file
supporting_arrays: bool, optional
If True, the function will return a dictionary with the supporting arrays
name : str, optional
The name of the metadata file in the zip archive
Expand All @@ -79,12 +92,29 @@ def load_metadata(path: str, name: str = DEFAULT_NAME) -> dict:

if metadata is not None:
with zipfile.ZipFile(path, "r") as f:
return json.load(f.open(metadata, "r"))
metadata = json.load(f.open(metadata, "r"))
if supporting_arrays:
metadata["supporting_arrays"] = load_supporting_arrays(f, metadata.get("supporting_arrays", {}))
return metadata, supporting_arrays

return metadata
else:
raise ValueError(f"Could not find '{name}' in {path}.")


def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> None:
def load_supporting_arrays(zipf, entries) -> dict:
import numpy as np

supporting_arrays = {}
for key, entry in entries.items():
supporting_arrays[key] = np.frombuffer(
zipf.read(entry["path"]),
dtype=entry["dtype"],
).reshape(entry["shape"])
return supporting_arrays


def save_metadata(path, metadata, *, supporting_arrays=None, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> None:
"""Save metadata to a checkpoint file
Parameters
Expand All @@ -93,6 +123,8 @@ def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> N
The path to the checkpoint file
metadata : JSON
A JSON serializable object
supporting_arrays: dict, optional
A dictionary of supporting NumPy arrays
name : str, optional
The name of the metadata file in the zip archive
folder : str, optional
Expand All @@ -118,19 +150,41 @@ def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> N

directory = list(directories)[0]

LOG.info("Adding extra information to checkpoint %s", path)
LOG.info("Saving metadata to %s/%s/%s", directory, folder, name)

metadata = metadata.copy()
if supporting_arrays is not None:
metadata["supporting_arrays_paths"] = {
key: dict(path=f"{directory}/{folder}/{key}.numpy", shape=value.shape, dtype=str(value.dtype))
for key, value in supporting_arrays.items()
}
else:
metadata["supporting_arrays_paths"] = {}

zipf.writestr(
f"{directory}/{folder}/{name}",
json.dumps(metadata),
)

for name, entry in metadata["supporting_arrays_paths"].items():
value = supporting_arrays[name]
LOG.info(
"Saving supporting array `%s` to %s (shape=%s, dtype=%s)",
name,
entry["path"],
entry["shape"],
entry["dtype"],
)
zipf.writestr(entry["path"], value.tobytes())


def _edit_metadata(path, name, callback):
def _edit_metadata(path, name, callback, supporting_arrays=None):
new_path = f"{path}.anemoi-edit-{time.time()}-{os.getpid()}.tmp"

found = False

directory = None
with TemporaryDirectory() as temp_dir:
zipfile.ZipFile(path, "r").extractall(temp_dir)
total = 0
Expand All @@ -141,10 +195,21 @@ def _edit_metadata(path, name, callback):
if f == name:
found = True
callback(full)
directory = os.path.dirname(full)

if not found:
raise ValueError(f"Could not find '{name}' in {path}")

if supporting_arrays is not None:

for key, entry in supporting_arrays.items():
value = entry.tobytes()
fname = os.path.join(directory, f"{key}.numpy")
os.makedirs(os.path.dirname(fname), exist_ok=True)
with open(fname, "wb") as f:
f.write(value)
total += 1

with zipfile.ZipFile(new_path, "w", zipfile.ZIP_DEFLATED) as zipf:
with tqdm.tqdm(total=total, desc="Rebuilding checkpoint") as pbar:
for root, dirs, files in os.walk(temp_dir):
Expand All @@ -158,7 +223,7 @@ def _edit_metadata(path, name, callback):
LOG.info("Updated metadata in %s", path)


def replace_metadata(path, metadata, name=DEFAULT_NAME):
def replace_metadata(path, metadata, supporting_arrays=None, *, name=DEFAULT_NAME):

if not isinstance(metadata, dict):
raise ValueError(f"metadata must be a dict, got {type(metadata)}")
Expand All @@ -170,14 +235,14 @@ def callback(full):
with open(full, "w") as f:
json.dump(metadata, f)

_edit_metadata(path, name, callback)
return _edit_metadata(path, name, callback, supporting_arrays)


def remove_metadata(path, name=DEFAULT_NAME):
def remove_metadata(path, *, name=DEFAULT_NAME):

LOG.info("Removing metadata '%s' from %s", name, path)

def callback(full):
os.remove(full)

_edit_metadata(path, name, callback)
return _edit_metadata(path, name, callback)
5 changes: 3 additions & 2 deletions src/anemoi/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def check_config_mode(name="settings.toml", secrets_name=None, secrets=None) ->
CHECKED[name] = True


def find(metadata, what, result=None):
def find(metadata, what, result=None, *, select: callable = None):
if result is None:
result = []

Expand All @@ -369,7 +369,8 @@ def find(metadata, what, result=None):

if isinstance(metadata, dict):
if what in metadata:
result.append(metadata[what])
if select is None or select(metadata[what]):
result.append(metadata[what])

for k, v in metadata.items():
find(v, what, result)
Expand Down
98 changes: 98 additions & 0 deletions src/anemoi/utils/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import importlib
import logging
import os
import sys

import entrypoints

LOG = logging.getLogger(__name__)


class Wrapper:
"""A wrapper for the registry"""

def __init__(self, name, registry):
self.name = name
self.registry = registry

def __call__(self, factory):
self.registry.register(self.name, factory)
return factory


class Registry:
"""A registry of factories"""

def __init__(self, package):

self.package = package
self.registered = {}
self.kind = package.split(".")[-1]

def register(self, name: str, factory: callable = None):

if factory is None:
return Wrapper(name, self)

self.registered[name] = factory

def _load(self, file):
name, _ = os.path.splitext(file)
try:
importlib.import_module(f".{name}", package=self.package)
except Exception:
LOG.warning(f"Error loading filter '{self.package}.{name}'", exc_info=True)

def lookup(self, name: str) -> callable:
if name in self.registered:
return self.registered[name]

directory = sys.modules[self.package].__path__[0]

for file in os.listdir(directory):

if file[0] == ".":
continue

if file == "__init__.py":
continue

full = os.path.join(directory, file)
if os.path.isdir(full):
if os.path.exists(os.path.join(full, "__init__.py")):
self._load(file)
continue

if file.endswith(".py"):
self._load(file)

entrypoint_group = f"anemoi.{self.kind}"
for entry_point in entrypoints.get_group_all(entrypoint_group):
if entry_point.name == name:
if name in self.registered:
LOG.warning(
f"Overwriting builtin '{name}' from {self.package} with plugin '{entry_point.module_name}'"
)
self.registered[name] = entry_point.load()

if name not in self.registered:
raise ValueError(f"Cannot load '{name}' from {self.package}")

return self.registered[name]

def create(self, name: str, *args, **kwargs):
factory = self.lookup(name)
return factory(*args, **kwargs)

def __call__(self, name: str, *args, **kwargs):
return self.create(name, *args, **kwargs)

0 comments on commit 4c2329a

Please sign in to comment.