From 5432739adb714d34f20c86417e62d49e657a2ec7 Mon Sep 17 00:00:00 2001
From: Fripping <124574028+Fripping@users.noreply.github.com>
Date: Fri, 26 Jul 2024 16:03:35 +0800
Subject: [PATCH] [Typing][B-42][BUAA] Add type annotations for
 `python/paddle/autograd/py_layer.py` (#66328)

---
 python/paddle/autograd/py_layer.py | 46 +++++++++++++++++++----------
 1 file changed, 30 insertions(+), 16 deletions(-)

diff --git a/python/paddle/autograd/py_layer.py b/python/paddle/autograd/py_layer.py
index 2843560f4a878c..be14b4b0a7ae2d 100644
--- a/python/paddle/autograd/py_layer.py
+++ b/python/paddle/autograd/py_layer.py
@@ -12,18 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Callable, Sequence, TypeVar
+
+from typing_extensions import Concatenate
+
 import paddle
 from paddle.base import core
 
-__all__ = []
+if TYPE_CHECKING:
+    from paddle import Tensor
 
 
-def with_metaclass(meta, *bases):
-    class impl(meta):
-        def __new__(cls, name, temp_bases, attrs):
-            return meta(name, bases, attrs)
+__all__ = []
+
 
-    return type.__new__(impl, "impl", (), {})
+_RetT = TypeVar('_RetT')
 
 
 class PyLayerContext:
@@ -52,7 +57,12 @@ class PyLayerContext:
             ...     return grad
     """
 
-    def save_for_backward(self, *tensors):
+    container: tuple[Tensor, ...]
+    not_inplace_tensors: tuple[Tensor, ...]
+    non_differentiable: tuple[Tensor, ...]
+    materialize_grads: bool
+
+    def save_for_backward(self, *tensors: Tensor) -> None:
         """
         Saves given tensors that backward need. Use ``saved_tensor`` in the `backward` to get the saved tensors.
 
@@ -90,7 +100,7 @@ def save_for_backward(self, *tensors):
         """
         self.container = tensors
 
-    def saved_tensor(self):
+    def saved_tensor(self) -> tuple[Tensor, ...]:
         """
         Get the tensors stored by ``save_for_backward``.
 
@@ -122,7 +132,7 @@ def saved_tensor(self):
         """
         return self.container
 
-    def mark_not_inplace(self, *args):
+    def mark_not_inplace(self, *args: Tensor) -> None:
         """
         Marks inputs as not inplace.
         This should be called at most once, only from inside the `forward` method,
@@ -163,7 +173,7 @@ def mark_not_inplace(self, *args):
         """
         self.not_inplace_tensors = args
 
-    def mark_non_differentiable(self, *args):
+    def mark_non_differentiable(self, *args: Tensor) -> None:
         """
         Marks outputs as non-differentiable.
         This should be called at most once, only from inside the `forward` method,
@@ -203,7 +213,7 @@ def mark_non_differentiable(self, *args):
         """
         self.non_differentiable = args
 
-    def set_materialize_grads(self, value: bool):
+    def set_materialize_grads(self, value: bool) -> None:
         """
         Sets whether to materialize output grad tensors. Default is True.
 
@@ -267,7 +277,7 @@ def __init__(cls, name, bases, attrs):
         return super().__init__(name, bases, attrs)
 
 
-class PyLayer(with_metaclass(PyLayerMeta, core.eager.PyLayer, PyLayerContext)):
+class PyLayer(core.eager.PyLayer, PyLayerContext, metaclass=PyLayerMeta):
     """
     Paddle implements Python custom operators on the PaddlePaddle framework by
     creating a subclass of ``PyLayer``, which must comply with the following rules:
@@ -322,7 +332,9 @@ class PyLayer(with_metaclass(PyLayerMeta, core.eager.PyLayer, PyLayerContext)):
     """
 
     @staticmethod
-    def forward(ctx, *args, **kwargs):
+    def forward(
+        ctx: PyLayerContext, *args: Any, **kwargs: Any
+    ) -> Tensor | Sequence[Tensor]:
         """
         It is to be overloaded by subclasses. It must accept a object of :ref:`api_paddle_autograd_PyLayerContext` as
         the first argument, followed by any number of arguments (tensors or other types).
@@ -361,7 +373,7 @@ def forward(ctx, *args, **kwargs):
         )
 
     @staticmethod
-    def backward(ctx, *args):
+    def backward(ctx: PyLayerContext, *args: Any) -> Tensor | Sequence[Tensor]:
         """
         This is a function to calculate the gradient. It is to be overloaded by subclasses.
         It must accept a object of :ref:`api_paddle_autograd_PyLayerContext` as the first
@@ -402,8 +414,10 @@ def backward(ctx, *args):
         )
 
 
-def once_differentiable(backward):
-    def wrapper(ctx, *args):
+def once_differentiable(
+    backward: Callable[Concatenate[PyLayerContext, ...], _RetT]
+) -> Callable[Concatenate[PyLayerContext, ...], _RetT]:
+    def wrapper(ctx: PyLayerContext, *args: Any) -> _RetT:
         with paddle.base.dygraph.no_grad():
             outputs = backward(ctx, *args)
         return outputs
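-- 
Reviewer note: a minimal sketch of how the annotated API reads at a call
site. It mirrors the `cus_tanh` example already used in this module's
docstrings; the input shape and dtype below are illustrative, not part of
the patch.

    import paddle
    from paddle.autograd import PyLayer

    class cus_tanh(PyLayer):
        @staticmethod
        def forward(ctx, x):
            # ctx is a PyLayerContext; save_for_backward is now typed
            # as (*tensors: Tensor) -> None.
            y = paddle.tanh(x)
            ctx.save_for_backward(y)
            return y

        @staticmethod
        def backward(ctx, dy):
            # saved_tensor() is now annotated to return tuple[Tensor, ...],
            # so unpacking it type-checks cleanly.
            (y,) = ctx.saved_tensor()
            return dy * (1 - paddle.square(y))

    x = paddle.rand([2, 3], dtype="float32")
    x.stop_gradient = False  # make x participate in autograd
    y = cus_tanh.apply(x)
    y.sum().backward()
    print(x.grad.shape)  # [2, 3]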
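The `Concatenate[PyLayerContext, ...]` annotation on `once_differentiable`
says the decorator accepts any backward whose first parameter is a
`PyLayerContext`, and preserves its return type via `_RetT`. A sketch of
the intended usage, assuming the decorator composes with a user-defined
`backward` the same way it is applied internally; the `scale` class is
hypothetical:

    from paddle.autograd import PyLayer
    from paddle.autograd.py_layer import once_differentiable

    class scale(PyLayer):
        @staticmethod
        def forward(ctx, x):
            return 2 * x

        @staticmethod
        @once_differentiable  # runs backward under no_grad; _RetT is kept
        def backward(ctx, dy):
            return 2 * dy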