
Commit

[Typing][B-42][BUAA] Add type annotations for `python/paddle/autograd/py_layer.py` (PaddlePaddle#66328)
Fripping authored and Dale1314 committed Jul 28, 2024
1 parent 34ee7fa commit 5432739
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions python/paddle/autograd/py_layer.py
@@ -12,18 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Callable, Sequence, TypeVar
+
+from typing_extensions import Concatenate
+
 import paddle
 from paddle.base import core
 
-__all__ = []
-
-
-def with_metaclass(meta, *bases):
-    class impl(meta):
-        def __new__(cls, name, temp_bases, attrs):
-            return meta(name, bases, attrs)
-
-    return type.__new__(impl, "impl", (), {})
+if TYPE_CHECKING:
+    from paddle import Tensor
+
+
+__all__ = []
+
+
+_RetT = TypeVar('_RetT')
 
 
 class PyLayerContext:
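For readers unfamiliar with the pattern introduced in this hunk: under `from __future__ import annotations`, annotations are never evaluated at runtime, so names imported only inside `if TYPE_CHECKING:` (such as `Tensor`) cost nothing at import time and cannot create circular imports, while `_RetT = TypeVar('_RetT')` lets a decorator preserve the wrapped function's return type. A minimal, self-contained sketch of the same idiom (not part of the diff; `log_calls` and `scale` are hypothetical names):

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, TypeVar

if TYPE_CHECKING:
    # Seen by static type checkers only; never imported at runtime.
    from paddle import Tensor

_RetT = TypeVar('_RetT')  # "the wrapped function's return type, whatever it is"


def log_calls(func: Callable[..., _RetT]) -> Callable[..., _RetT]:
    # Annotating the decorator with _RetT keeps the wrapped return type visible
    # to type checkers, as once_differentiable does below with Concatenate.
    def wrapper(*args, **kwargs) -> _RetT:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return wrapper


def scale(x: Tensor, factor: float) -> Tensor:
    # With postponed evaluation, this annotation is a plain string at runtime,
    # so Tensor never has to be imported outside of type checking.
    return x * factor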
@@ -52,7 +57,12 @@ class PyLayerContext:
             ...     return grad
     """
 
-    def save_for_backward(self, *tensors):
+    container: tuple[Tensor, ...]
+    not_inplace_tensors: tuple[Tensor, ...]
+    non_differentiable: tuple[Tensor, ...]
+    materialize_grads: bool
+
+    def save_for_backward(self, *tensors: Tensor) -> None:
         """
         Saves given tensors that backward need. Use ``saved_tensor`` in the `backward` to get the saved tensors.
@@ -90,7 +100,7 @@ def save_for_backward(self, *tensors):
         """
         self.container = tensors
 
-    def saved_tensor(self):
+    def saved_tensor(self) -> tuple[Tensor, ...]:
         """
         Get the tensors stored by ``save_for_backward``.
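A minimal usage sketch (not part of the diff) of the `save_for_backward`/`saved_tensor` pair these annotations describe; the `CusTanh` layer and shapes below are hypothetical, in the style of the surrounding docstring examples:

import paddle
from paddle.autograd import PyLayer


class CusTanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        # Stash tensors needed by backward; retrieved later via ctx.saved_tensor().
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        (y,) = ctx.saved_tensor()
        # d tanh(x)/dx = 1 - tanh(x)^2
        return dy * (1 - paddle.square(y))


x = paddle.randn([3, 4])
x.stop_gradient = False
out = CusTanh.apply(x)
out.sum().backward()
print(x.grad.shape)  # [3, 4]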
@@ -122,7 +132,7 @@ def saved_tensor(self):
         """
         return self.container
 
-    def mark_not_inplace(self, *args):
+    def mark_not_inplace(self, *args: Tensor) -> None:
        """
        Marks inputs as not inplace.
        This should be called at most once, only from inside the `forward` method,
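As a hedged illustration of `mark_not_inplace` (the `Identity` layer below is hypothetical): when `forward` returns an input tensor unchanged, marking that input not-inplace asks Paddle to hand back a new output tensor rather than aliasing the input.

import paddle
from paddle.autograd import PyLayer


class Identity(PyLayer):
    @staticmethod
    def forward(ctx, x):
        # x is returned as-is; marking it not-inplace means the returned
        # tensor is a fresh output rather than the input object itself.
        ctx.mark_not_inplace(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output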
@@ -163,7 +173,7 @@ def mark_not_inplace(self, *args):
        """
        self.not_inplace_tensors = args
 
-    def mark_non_differentiable(self, *args):
+    def mark_non_differentiable(self, *args: Tensor) -> None:
        """
        Marks outputs as non-differentiable.
        This should be called at most once, only from inside the `forward` method,
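A sketch under the same assumptions for `mark_non_differentiable` (hypothetical `WithMask` layer): the auxiliary output is marked so no gradient is required for it, and whatever arrives in its `backward` slot is ignored.

import paddle
from paddle.autograd import PyLayer


class WithMask(PyLayer):
    @staticmethod
    def forward(ctx, x):
        y = x * 2
        mask = (x > 0).astype(x.dtype)  # auxiliary output, no gradient wanted
        ctx.mark_non_differentiable(mask)
        return y, mask

    @staticmethod
    def backward(ctx, grad_y, grad_mask):
        # grad_mask carries no useful signal; only the path through y reaches x.
        return grad_y * 2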
@@ -203,7 +213,7 @@ def mark_non_differentiable(self, *args):
        """
        self.non_differentiable = args
 
-    def set_materialize_grads(self, value: bool):
+    def set_materialize_grads(self, value: bool) -> None:
        """
        Sets whether to materialize output grad tensors. Default is True.
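A sketch of what `set_materialize_grads(False)` changes (hypothetical `TwoOutputs` layer): an output that received no gradient shows up in `backward` as `None` instead of a zero-filled tensor, so the handler has to check for it.

import paddle
from paddle.autograd import PyLayer


class TwoOutputs(PyLayer):
    @staticmethod
    def forward(ctx, x):
        ctx.set_materialize_grads(False)
        return x + 1, x + 2

    @staticmethod
    def backward(ctx, grad1, grad2):
        # With materialization off, an unused output's grad is None, not zeros.
        if grad2 is None:
            return grad1
        return grad1 + grad2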
@@ -267,7 +277,7 @@ def __init__(cls, name, bases, attrs):
         return super().__init__(name, bases, attrs)
 
 
-class PyLayer(with_metaclass(PyLayerMeta, core.eager.PyLayer, PyLayerContext)):
+class PyLayer(core.eager.PyLayer, PyLayerContext, metaclass=PyLayerMeta):
     """
     Paddle implements Python custom operators on the PaddlePaddle framework by creating a subclass of
     ``PyLayer``, which must comply with the following rules:
@@ -322,7 +332,9 @@ class PyLayer(with_metaclass(PyLayerMeta, core.eager.PyLayer, PyLayerContext)):
     """
 
     @staticmethod
-    def forward(ctx, *args, **kwargs):
+    def forward(
+        ctx: PyLayerContext, *args: Any, **kwargs: Any
+    ) -> Tensor | Sequence[Tensor]:
        """
        It is to be overloaded by subclasses. It must accept a object of :ref:`api_paddle_autograd_PyLayerContext` as
        the first argument, followed by any number of arguments (tensors or other types).
@@ -361,7 +373,7 @@ def forward(ctx, *args, **kwargs):
         )
 
     @staticmethod
-    def backward(ctx, *args):
+    def backward(ctx: PyLayerContext, *args: Any) -> Tensor | Sequence[Tensor]:
        """
        This is a function to calculate the gradient. It is to be overloaded by subclasses.
        It must accept a object of :ref:`api_paddle_autograd_PyLayerContext` as the first
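Putting the two annotated signatures together, a hedged sketch of a multi-input subclass (the `MulAdd` layer is hypothetical): `backward` returns one gradient per Tensor input of `forward`, which is what the `Tensor | Sequence[Tensor]` return annotation allows for.

from __future__ import annotations

import paddle
from paddle import Tensor
from paddle.autograd import PyLayer, PyLayerContext


class MulAdd(PyLayer):
    @staticmethod
    def forward(ctx: PyLayerContext, x: Tensor, y: Tensor) -> Tensor:
        ctx.save_for_backward(x, y)
        return x * y + y

    @staticmethod
    def backward(ctx: PyLayerContext, grad_out: Tensor) -> tuple[Tensor, Tensor]:
        x, y = ctx.saved_tensor()
        # d(x*y + y)/dx = y    d(x*y + y)/dy = x + 1
        return grad_out * y, grad_out * (x + 1)


x = paddle.rand([2, 3])
y = paddle.rand([2, 3])
x.stop_gradient = False
y.stop_gradient = False
out = MulAdd.apply(x, y)
out.sum().backward()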
@@ -402,8 +414,10 @@ def backward(ctx, *args):
         )
 
 
-def once_differentiable(backward):
-    def wrapper(ctx, *args):
+def once_differentiable(
+    backward: Callable[Concatenate[PyLayerContext, ...], _RetT]
+) -> Callable[Concatenate[PyLayerContext, ...], _RetT]:
+    def wrapper(ctx: PyLayerContext, *args: Any) -> _RetT:
         with paddle.base.dygraph.no_grad():
             outputs = backward(ctx, *args)
         return outputs
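Finally, a usage sketch for the annotated `once_differentiable` decorator, assuming the import path of this file (`paddle.autograd.py_layer`) and the usual ordering with `@staticmethod` outermost; the `Square` layer is hypothetical:

import paddle
from paddle.autograd import PyLayer
from paddle.autograd.py_layer import once_differentiable


class Square(PyLayer):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        # The wrapper runs this body under no_grad, so the backward pass itself
        # is not recorded and cannot be differentiated a second time.
        (x,) = ctx.saved_tensor()
        return 2 * x * grad_out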
