Skip to content

Commit

Permalink
Add utility to create a deprecated alias.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622824027
  • Loading branch information
mtthss authored and ChexDev committed Apr 8, 2024
1 parent 8b889ff commit 0ae3287
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 0 deletions.
2 changes: 2 additions & 0 deletions chex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
from chex._src.variants import params_product
from chex._src.variants import TestCase
from chex._src.variants import variants
from chex._src.warnings import create_deprecated_function_alias
from chex._src.warnings import warn_deprecated_function
from chex._src.warnings import warn_keyword_args_only_in_future
from chex._src.warnings import warn_only_n_pos_args_in_future
Expand Down Expand Up @@ -168,6 +169,7 @@
"chexify",
"ChexVariantType",
"clear_trace_counter",
"create_deprecated_function_alias",
"dataclass",
"Device",
"Dimensions",
Expand Down
26 changes: 26 additions & 0 deletions chex/_src/warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,29 @@ def new_fun(*args, **kwargs):
stacklevel=2)
return fun(*args, **kwargs)
return new_fun


def create_deprecated_function_alias(fun, new_name, deprecated_alias):
"""Create a deprecated alias for a function.
Example usage:
>>> g = create_deprecated_function_alias(f, 'path.f', 'path.g')
Args:
fun: the deprecated function.
new_name: the new name to use (you may include the path for clarity).
deprecated_alias: the old name (you may include the path for clarity).
Returns:
the wrapped function.
"""

@functools.wraps(fun)
def new_fun(*args, **kwargs):
warnings.warn(
f'The function {deprecated_alias} is deprecated, '
f'please use {new_name} instead.',
category=DeprecationWarning,
stacklevel=2)
return fun(*args, **kwargs)
return new_fun
9 changes: 9 additions & 0 deletions chex/_src/warnings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def g(a, b, c):
return a + b + c


def h1(a, b, c):
return a + b + c
h2 = warnings.create_deprecated_function_alias(h1, 'path.h2', 'path.h1')


class WarningsTest(absltest.TestCase):

def test_warn_only_n_pos_args_in_future(self):
Expand All @@ -43,6 +48,10 @@ def test_warn_deprecated_function(self):
with self.assertWarns(Warning):
g(1, 2, 3)

def test_create_deprecated_function_alias(self):
with self.assertWarns(Warning):
h2(1, 2, 3)


if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ Warnings

.. currentmodule:: chex

.. autofunction:: create_deprecated_function_alias
.. autofunction:: warn_deprecated_function
.. autofunction:: warn_keyword_args_only_in_future
.. autofunction:: warn_only_n_pos_args_in_future
Expand Down

0 comments on commit 0ae3287

Please sign in to comment.