From ae8d22d8f8d5c367ff6eb5227f733f0497668f90 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 6 Nov 2024 16:34:55 +0100 Subject: [PATCH] feat: Add aliases decorator (#40) * Add aliases function for backward compatibility --- CHANGELOG.md | 1 + src/anemoi/utils/compatibility.py | 76 +++++++++++++++++++++++++++++++ tests/test_compatibility.py | 32 +++++++++++++ 3 files changed, 109 insertions(+) create mode 100644 src/anemoi/utils/compatibility.py create mode 100644 tests/test_compatibility.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a91bf67..eb93819 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-utils/compare/0.4.1...HEAD) ### Added +- Add alias decorator [#40](https://github.com/ecmwf/anemoi-utils/pull/40) - Add supporting_arrays to checkpoints - Add factories registry - Optional renaming of subcommands via `command` attribute [#34](https://github.com/ecmwf/anemoi-utils/pull/34) diff --git a/src/anemoi/utils/compatibility.py b/src/anemoi/utils/compatibility.py new file mode 100644 index 0000000..699cdf8 --- /dev/null +++ b/src/anemoi/utils/compatibility.py @@ -0,0 +1,76 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import functools +from typing import Any +from typing import Callable + + +def aliases( + aliases: dict[str, str | list[str]] | None = None, **kwargs: str | list[str] +) -> Callable[[Callable], Callable]: + """Alias keyword arguments in a function call. + + Allows for dynamically renaming keyword arguments in a function call. + + Parameters + ---------- + aliases : dict[str, str | list[str]] | None, optional + Key, value pair of aliases, with keys being the true name, and value being a str or list of aliases, + by default None + **kwargs : str | list[str] + Kwargs form of aliases + + Returns + ------- + Callable + Decorator function that renames keyword arguments in a function call. + + Raises + ------ + ValueError + If the aliasing would result in duplicate keys. + + Examples + -------- + ```python + @aliases(a="b", c=["d", "e"]) + def func(a, c): + return a, c + + func(a=1, c=2) # (1, 2) + func(b=1, d=2) # (1, 2) + ``` + + """ + + if aliases is None: + aliases = {} + aliases.update(kwargs) + + aliases = {v: k for k, vs in aliases.items() for v in (vs if isinstance(vs, list) else [vs])} + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs) -> Any: + keys = kwargs.keys() + for k in set(keys).intersection(set(aliases.keys())): + if aliases[k] in keys: + raise ValueError( + f"When aliasing {k} with {aliases[k]} duplicate keys were present. Cannot include both." + ) + kwargs[aliases[k]] = kwargs.pop(k) + + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/tests/test_compatibility.py b/tests/test_compatibility.py new file mode 100644 index 0000000..9f270f6 --- /dev/null +++ b/tests/test_compatibility.py @@ -0,0 +1,32 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest + +from anemoi.utils.compatibility import aliases + + +def test_aliases() -> None: + + @aliases(a="b", c=["d", "e"]) + def func(a, c): + return a, c + + assert func(a=1, c=2) == (1, 2) + assert func(a=1, d=2) == (1, 2) + assert func(b=1, d=2) == (1, 2) + + +def test_duplicate_values() -> None: + @aliases(a="b", c=["d", "e"]) + def func(a, c): + return a, c + + with pytest.raises(ValueError): + func(a=1, b=2)