[Dy2Static] convert_call supports staticmethod for classes. (#44983)
* convert_call supports staticmethod for classes.

* while supports Python containers (see the sketch after this list).
  This makes it convenient to convert more dynamic-graph code into static graphs.

* cond supports Python containers.

* add unittest for staticmethod convert_call

* fix bugs

* add unittest for the item interface

* fix bugs

* change to np.testing.assert_allclose

* code format

* fix comments

* code format
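As an illustrative sketch (not code from this commit; exact behavior depends on the Paddle version), this is the kind of dygraph code these changes let Dy2Static convert: a plain Python list mutated inside a while loop.

import paddle

@paddle.jit.to_static
def collect(x):
    results = []  # plain Python list used inside converted control flow
    i = 0
    while i < 3:  # rewritten to convert_while_loop by Dy2Static
        results.append(x + i)  # .append() is tracked via push_pop_names
        i += 1
    return results.pop()  # list .pop() is supported as well

print(collect(paddle.ones([1])))  # Tensor holding 3.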
2742195759 authored Sep 9, 2022
1 parent 2b4f44d commit d0096ea
Showing 7 changed files with 82 additions and 6 deletions.
@@ -123,7 +123,10 @@ def visit_FunctionDef(self, node):
# Remove the decorated name of dygraph_to_static
if hasattr(node, 'decorator_list'):
    decorator_list = []
    ignore_list = ["staticmethod"]
    for d in node.decorator_list:
        if isinstance(d, gast.Name) and d.id in ignore_list:
            continue
        if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES:
            raise NotImplementedError(
                "ProgramTranslator hasn't implemented multiple decorators. Please remove "
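The new ignore_list makes the transformer skip the @staticmethod decorator instead of rejecting it. A standalone sketch of the same filter, using gast directly (names mirror the hunk above):

import gast  # the AST library Dy2Static builds on

tree = gast.parse("class A:\n"
                  "    @staticmethod\n"
                  "    def f(x):\n"
                  "        return x\n")
func = tree.body[0].body[0]  # the FunctionDef node for f

# Mirror the transformer's filter: drop decorators named in ignore_list.
ignore_list = ["staticmethod"]
func.decorator_list = [
    d for d in func.decorator_list
    if not (isinstance(d, gast.Name) and d.id in ignore_list)
]
print(len(func.decorator_list))  # 0, @staticmethod was stripped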
@@ -213,6 +213,11 @@ def dyfunc(x):
elif isinstance(fn, StaticFunction):
    _, fn = unwrap_decorators(fn)
    global_functions.add(fn)
elif inspect.isclass(fn):
    if isinstance(fn.__dict__.get(func.__name__, None), staticmethod):
        global_functions.add(func)  # Add func to ensure that we will convert it.

if func in global_functions:
    converted_call = convert_to_static(func)
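The fn.__dict__ lookup is what detects the staticmethod: attribute access on the class already unwraps the descriptor, so only the class __dict__ still holds it. A plain-Python illustration (no Paddle required):

import inspect

class A:

    @staticmethod
    def add(a, b):
        return a + b

    def method(self):
        pass

print(type(A.add))  # <class 'function'>, already unwrapped by attribute access
print(isinstance(A.__dict__.get("add"), staticmethod))  # True
print(isinstance(A.__dict__.get("method"), staticmethod))  # False
print(inspect.isclass(A))  # True, so the new elif branch above is taken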
@@ -87,7 +87,10 @@ def convert_while_loop(cond,
Args:
    cond(Callable): A callable object that returns a boolean variable to control whether to execute the loop body. It takes ``loop_vars`` as arguments.
    body(Callable): A callable object that returns a tuple or list of variables with the same arguments ``loop_vars`` as ``cond``.
    loop_vars(list|tuple): A list or tuple of variables passed to ``cond`` and ``body``.
    get_args(callable): Get all arguments that are needed in ``cond`` and ``body``.
    set_args(callable): Update the arguments that are modified in ``cond`` and ``body``.
    return_name_ids(list[string], optional): The returned names.
    push_pop_names(list[string], optional): The names on which ``.append()`` or ``.pop()`` is called.
Returns:
    A list or tuple of variables returned by ``body``.
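A pure-Python sketch (not the Paddle implementation) of the get_args/set_args protocol documented here: loop state lives in the enclosing scope and is read and written back through the two closures.

def run_while(cond, body, get_args, set_args):
    # Minimal driver mirroring the documented contract.
    while cond(*get_args()):
        set_args(body(*get_args()))

state = {"i": 0, "total": 0}

def cond(i, total):
    return i < 5

def body(i, total):
    return i + 1, total + i

def get_args():
    return state["i"], state["total"]

def set_args(vals):
    state["i"], state["total"] = vals

run_while(cond, body, get_args, set_args)
print(state)  # {'i': 5, 'total': 10}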
@@ -306,7 +309,8 @@ def convert_ifelse(pred,
false_fn(callable): A callable to be performed if ``pred`` is false.
get_args(callable): Get all arguments that are needed in ``true_fn`` and ``false_fn``.
set_args(callable): Update the arguments that are modified in ``true_fn`` and ``false_fn``.
return_name_ids(list[string]): the returned names.
return_name_ids(list[string], optional): The returned names.
push_pop_names(list[string], optional): The names on which ``.append()`` or ``.pop()`` is called.
Returns:
    ``true_fn()`` if the predicate ``pred`` is true else ``false_fn()``.
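The same protocol in its conditional form, again as a pure-Python sketch rather than the Paddle internals:

def run_ifelse(pred, true_fn, false_fn, get_args, set_args):
    # Run the selected branch on the current args and write the results back.
    out = true_fn(*get_args()) if pred else false_fn(*get_args())
    set_args(out)
    return out

args = [3]
result = run_ifelse(pred=args[0] > 0,
                    true_fn=lambda x: (x * 2,),
                    false_fn=lambda x: (-x,),
                    get_args=lambda: tuple(args),
                    set_args=lambda vals: args.__setitem__(slice(None), list(vals)))
print(args, result)  # [6] (6,)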
@@ -380,7 +380,6 @@ def __call__(self, *args, **kwargs):
try:
    concrete_program, partial_program_layer = self.get_concrete_program(
        *args, **kwargs, is_train=self._is_train_mode())

    # 3. synchronize self.training attribute.
    if isinstance(self._class_instance, layers.Layer):
        partial_program_layer.training = self._class_instance.training
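This synchronization is what makes train/eval mode behave as expected after conversion. A hedged usage sketch (decorating forward with paddle.jit.to_static is standard usage; exact numerics depend on the Paddle version):

import paddle

class Net(paddle.nn.Layer):

    @paddle.jit.to_static
    def forward(self, x):
        # self.training is mirrored into the generated static program
        return paddle.nn.functional.dropout(x, p=0.5, training=self.training)

net = Net()
net.eval()  # flips net.training to False; the converted program follows
out = net(paddle.ones([4]))  # dropout becomes a no-op in eval mode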
25 changes: 22 additions & 3 deletions python/paddle/fluid/layers/math_op_patch.py
@@ -227,12 +227,30 @@ def append(self, var):
            .format(self.type))
    array_write(x=var, i=array_length(self), array=self)

@static_only
def _item(self):
    """
    In order to be compatible with the item interface introduced by the dynamic graph,
    it does nothing but return self, after checking that the shape is 1-D.
    """
    if len(self.shape) > 1:
        raise TypeError(
            "Required input var should be 1-D Variable, but received {}".format(
                self.shape))
    return self

@static_only
def pop(self, *args):
    """
    **Notes**:
        **The type variable must be LoD Tensor Array.
    The type of the variable must be LoD Tensor Array.
    When self is a LoDTensorArray, calling pop is similar to Python's pop on a list.
    This interface is used to simplify dygraph-to-static-graph operations.

    Args:
        self(Variable): The source variable, which must be a LOD_TENSOR_ARRAY.
        *args: optional; an int meaning the index to pop.
    Returns:
        Variable: self[index]
    """
    from paddle.fluid.dygraph.dygraph_to_static.convert_operators import _run_paddle_pop
    if self.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
@@ -410,6 +428,7 @@ def __impl__(self, other_var):
('cpu', cpu),
('cuda', cuda),
('append', append),
('item', _item),
('pop', pop),
('dim', lambda x: len(x.shape)),
('ndimension', lambda x: len(x.shape)),
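The registration list above follows a monkey-patching pattern: each (name, func) pair is bound onto the static-graph Variable class so converted code can call dygraph-style methods. A simplified stand-in (not the Paddle class) showing the mechanism:

class Variable:  # stand-in for paddle.fluid.framework.Variable
    def __init__(self, shape):
        self.shape = shape

def _item(self):
    # Mirrors the patched _item: check 1-D, then return self unchanged.
    if len(self.shape) > 1:
        raise TypeError(
            "Required input var should be 1-D Variable, but received {}".format(
                self.shape))
    return self

for method_name, method in [('item', _item),
                            ('dim', lambda x: len(x.shape))]:
    setattr(Variable, method_name, method)

v = Variable([1])
print(v.item() is v, v.dim())  # True 1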
@@ -65,6 +65,23 @@ def dyfunc_with_third_library_logging(x_v):
    return x_v


class A:

    @staticmethod
    def add(a, b):
        """
        In dygraph mode, return a numpy object;
        in static mode, return a variable object.
        """
        return paddle.to_tensor(a.numpy() + b.numpy())


@paddle.jit.to_static
def dyfunc_with_staticmethod(x_v):
    a = A()
    return a.add(x_v, x_v)


class TestRecursiveCall1(unittest.TestCase):

    def setUp(self):
@@ -188,6 +205,12 @@ def set_func(self):
    self.dygraph_func = dyfunc_with_third_library_logging


class TestStaticMethod(TestRecursiveCall2):

    def set_func(self):
        self.dygraph_func = dyfunc_with_staticmethod


# Situation 2: test not_to_static
@@ -62,6 +62,29 @@ def test_to_static_numpy_report_error(self):
    static_res = self._run(to_static=True)


@paddle.jit.to_static
def tensor_item(x):
    x = paddle.to_tensor(x)
    y = x.clone()
    return y.item()


class TestTensorItem(unittest.TestCase):

    def _run(self, to_static):
        prog_trans = paddle.jit.ProgramTranslator()
        prog_trans.enable(to_static)
        x = paddle.ones([1])
        if to_static:
            return tensor_item(x).numpy()
        return tensor_item(x)

    def test_tensor_item(self):
        dygraph_res = self._run(to_static=False)
        static_res = self._run(to_static=True)
        np.testing.assert_allclose(dygraph_res, static_res)


@paddle.jit.to_static
def tensor_size(x):
    x = paddle.to_tensor(x)
