Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚀 [Feature] Add deprecated Decorator #3161

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@
from .megatron_lm import prepare_model_optimizer_scheduler as megatron_lm_prepare_model_optimizer_scheduler
from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer
from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler
from .deprecation import deprecated
from .memory import find_executable_batch_size, release_memory
from .other import (
check_os_kernel,
Expand Down
80 changes: 80 additions & 0 deletions src/accelerate/utils/deprecation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import textwrap
import warnings
from typing import Callable, TypeVar

from typing_extensions import ParamSpec


_T = TypeVar("_T")
_P = ParamSpec("_P")


def deprecated(since: str, removed_in: str, instructions: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""Marks functions as deprecated.

It will result in a warning when the function is called and a note in the docstring.

Args:
since (`str`):
The version when the function was first deprecated.
removed_in (`str`):
The version when the function will be removed.
instructions (`str`):
The action users should take.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a nit: I'd find it more intuitive if the instructions was renamed instruction and if the "Please " was not added in front, as I would assume as a caller that I need to pass a full sentence here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in a7a43aa and 2cb4648 Thank you :)


Returns:
`Callable`: A decorator that will mark the function as deprecated.
"""

def decorator(function: Callable[_P, _T]) -> Callable[_P, _T]:
@functools.wraps(function)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
warnings.warn(
f"'{function.__module__}.{function.__name__}' "
f"is deprecated in version {since} and will be "
f"removed in {removed_in}. Please {instructions}.",
category=FutureWarning,
stacklevel=2,
)
return function(*args, **kwargs)

# Add a deprecation note to the docstring.
docstring = function.__doc__ or ""

deprecation_note = textwrap.dedent(
f"""\
.. deprecated:: {since}
Deprecated and will be removed in version {removed_in}. Please {instructions}.
"""
)

# Split docstring at first occurrence of newline
summary_and_body = docstring.split("\n\n", 1)
if len(summary_and_body) > 1:
summary, body = summary_and_body
body = textwrap.dedent(body)
new_docstring_parts = [deprecation_note, "\n\n", summary, body]
else:
summary = summary_and_body[0]
new_docstring_parts = [deprecation_note, "\n\n", summary]

wrapper.__doc__ = "".join(new_docstring_parts)

return wrapper

return decorator
6 changes: 2 additions & 4 deletions src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ..state import AcceleratorState
from .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from .dataclasses import AutocastKwargs, CustomDtype, DistributedType
from .deprecation import deprecated
from .imports import (
is_mlu_available,
is_mps_available,
Expand Down Expand Up @@ -471,11 +472,8 @@ class FindTiedParametersResult(list):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@deprecated(since="1.0.0rc0", removed_in="1.3.0", instructions="use another method instead")
def values(self):
warnings.warn(
"The 'values' method of FindTiedParametersResult is deprecated and will be removed in Accelerate v1.3.0. ",
FutureWarning,
)
return sum([x[1:] for x in self], [])


Expand Down
57 changes: 57 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import pickle
import tempfile
import textwrap
import unittest
import warnings
from collections import UserDict, namedtuple
Expand Down Expand Up @@ -54,6 +55,7 @@
save,
send_to_device,
)
from accelerate.utils.deprecation import deprecated
from accelerate.utils.operations import is_namedtuple


Expand Down Expand Up @@ -413,3 +415,58 @@ def test_convert_dict_to_env_variables(self):
with self.assertLogs("accelerate.utils.environment", level="WARNING"):
valid_env_items = convert_dict_to_env_variables(env)
assert valid_env_items == ["ACCELERATE_DEBUG_MODE=1\n", "OTHER_ENV=2\n"]

def test_deprecated(self):
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved
@deprecated("0.2.0", "0.3.0", "toy instruction")
def long_deprecated_demo(arg1: int, arg2: int) -> tuple:
"""This is a long summary. This is a long summary. This is a long
summary. This is a long summary.

Args:
arg1 (int): Description.
arg2 (int): Description.

Returns:
Description.
"""
return arg1, arg2

with pytest.warns(
FutureWarning, match="deprecated in version 0.2.0 and will be removed in 0.3.0. Please toy instruction."
):
self.assertEqual((1, 2), long_deprecated_demo(1, 2))

long_expected_docstring = textwrap.dedent("""
.. deprecated:: 0.2.0
Deprecated and will be removed in version 0.3.0. Please toy instruction.

This is a long summary. This is a long summary. This is a long
summary. This is a long summary.

Args:
arg1 (int): Description.
arg2 (int): Description.

Returns:
Description.
""")

long_expected_docstring = "".join(long_expected_docstring.split())
long_actual_docstring = "".join(long_deprecated_demo.__doc__.split())

self.assertEqual(long_expected_docstring, long_actual_docstring)

@deprecated("0.2.0", "0.3.0", "toy instruction")
def short_deprecated_demo():
"""Short summary."""

short_expected_docstring = textwrap.dedent("""
.. deprecated:: 0.2.0
Deprecated and will be removed in version 0.3.0. Please toy instruction.

Short summary.
""")
short_expected_docstring = "".join(short_expected_docstring.split())
short_actual_docstring = "".join(short_deprecated_demo.__doc__.split())

self.assertEqual(short_expected_docstring, short_actual_docstring)