Merge pull request #14 from feifei-111/jitsvae
support jit.save
2742195759 authored Jun 27, 2023
2 parents dfef6dc + 06609f5 commit 7ecbf64
Show file tree
Hide file tree
Showing 15 changed files with 59 additions and 38 deletions.
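For context, a minimal sketch of the workflow this change targets: calling a SOT-mode to_static layer once so its input spec is recorded, then exporting it with paddle.jit.save. The layer, shapes, and save path below are illustrative, not taken from the diff.

import paddle


class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(10, 3)

    def forward(self, x):
        return self.linear(x)


net = paddle.jit.to_static(SimpleNet())  # SOT/fallback mode by default

# Run once so the call's input spec is recorded; this PR adds
# last_call_input_spec so jit.save can rebuild an AST static function
# from it when no explicit input_spec was given.
x = paddle.randn([4, 10], dtype='float32')
net(x)

paddle.jit.save(net, './simple_net')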
53 changes: 44 additions & 9 deletions python/paddle/jit/api.py
@@ -25,6 +25,7 @@
import inspect
import threading
from typing import Any
import types

import paddle
from paddle.fluid import core, dygraph
@@ -1043,7 +1044,9 @@ def fun(inputs):
concrete_program = None
for attr_func in functions:
if isinstance(layer, Layer):
static_func = getattr(inner_layer, attr_func, None)
static_func = get_ast_static_function(
getattr(inner_layer, attr_func, None)
)
if isinstance(static_func, StaticFunction):
if static_func.is_property:
# property method to be exported
@@ -1076,7 +1079,9 @@ def fun(inputs):
input_spec, inner_input_spec
)
static_forward = to_static(
inner_layer.forward, input_spec=inner_input_spec
inner_layer.forward,
input_spec=inner_input_spec,
enable_fallback=False,
)
concrete_program = (
static_forward.concrete_program_specify_input_spec(
@@ -1092,14 +1097,16 @@ def fun(inputs):
else:
# When layer is a function
if isinstance(attr_func, StaticFunction):
if attr_func.is_property:
static_func = get_ast_static_function(attr_func)

if static_func.is_property:
# property method to be exported
immediate_val = attr_func()
property_vals.append((immediate_val, attr_func))
immediate_val = static_func()
property_vals.append((immediate_val, static_func))
continue

concrete_program = (
attr_func.concrete_program_specify_input_spec(
static_func.concrete_program_specify_input_spec(
inner_input_spec, is_prim_infer=is_prim_infer
)
)
@@ -1109,7 +1116,9 @@ def fun(inputs):
input_spec, inner_input_spec
)
static_function = to_static(
attr_func, input_spec=inner_input_spec
static_func,
input_spec=inner_input_spec,
enable_fallback=False,
)
concrete_program = static_function.concrete_program

@@ -1125,9 +1134,9 @@ def fun(inputs):
if isinstance(inner_layer, Layer):
dygraph_state_dict = inner_layer.to_static_state_dict()
elif isinstance(attr_func, StaticFunction):
if attr_func._class_instance:
if static_func._class_instance:
dygraph_state_dict = (
attr_func._class_instance.to_static_state_dict()
static_func._class_instance.to_static_state_dict()
)

if dygraph_state_dict:
@@ -1886,3 +1895,29 @@ def get_feed_fetch(all_vars, partial_vars):
clip_extra=clip_extra,
legacy_format=legacy_format,
)


def get_ast_static_function(function):
if isinstance(function, SymbolicStaticFunction):
if function._class_instance:
dygraph_function = types.MethodType(
function._dygraph_function, function._class_instance
)
else:
dygraph_function = function._dygraph_function

if function._function_spec._input_spec is None:
ast_static_function = ASTStaticFunction(
dygraph_function,
function.last_call_input_spec,
**function._kwargs,
)
return ast_static_function
else:
ast_static_function = ASTStaticFunction(
dygraph_function,
function._function_spec._input_spec,
**function._kwargs,
)
return ast_static_function
return function
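The types.MethodType call in get_ast_static_function rebinds the stored plain _dygraph_function to its _class_instance so the rebuilt AST function sees self again. A standalone sketch of that rebinding mechanism (toy class, not Paddle code):

import types


class Counter:
    def __init__(self):
        self.n = 0


def increment(self):
    self.n += 1
    return self.n


c = Counter()
# Bind the free function to the instance, mirroring how the helper
# re-attaches _dygraph_function to _class_instance before building
# the ASTStaticFunction.
bound = types.MethodType(increment, c)
bound()
assert c.n == 1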
11 changes: 11 additions & 0 deletions python/paddle/jit/dy2static/program_translator.py
@@ -674,8 +674,19 @@ def __init__(self, function, input_spec=None, **kwargs):
"1. You can disable fallback mode by `paddle.jit.to_static(enable_fallback=False)` to switch to AST to static, then you can assign input spec.\n"
)
super().__init__(function, input_spec, **kwargs)
self.last_call_input_spec = None

def _perform_call(self, *args, **kwargs):
args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
(
input_args_with_spec,
input_kwargs_with_spec,
) = self._function_spec.args_to_input_spec(args, kwargs)
self.last_call_input_spec = input_args_with_spec

from sot import symbolic_translate

build_strategy = self._kwargs.get("build_strategy", None)
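The spec recorded in _perform_call is what get_ast_static_function falls back to when the user never passed input_spec. A simplified model of that record-on-call pattern (hypothetical wrapper, not Paddle's actual FunctionSpec machinery):

class RecordingWrapper:
    def __init__(self, fn):
        self.fn = fn
        self.last_call_input_spec = None

    def __call__(self, *args):
        # Capture a lightweight shape/type stand-in for each positional
        # argument before dispatching the real call.
        self.last_call_input_spec = [
            (getattr(a, 'shape', None), type(a).__name__) for a in args
        ]
        return self.fn(*args)


def double(x):
    return x * 2


wrapped = RecordingWrapper(double)
wrapped(3)
print(wrapped.last_call_input_spec)  # [(None, 'int')]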
1 change: 0 additions & 1 deletion test/dygraph_to_static/test_bert.py
@@ -237,7 +237,6 @@ def test_train(self):

self.verify_predict()

@ast_only_test
def test_train_composite(self):
core._set_prim_backward_enabled(True)
# core._add_skip_comp_ops("layer_norm")
3 changes: 1 addition & 2 deletions test/dygraph_to_static/test_bmn.py
@@ -18,7 +18,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from dygraph_to_static_util import dy2static_unittest
from predictor_utils import PredictorTools

import paddle
@@ -752,7 +752,6 @@ def train_bmn(self, args, place, to_static):
break
return np.array(loss_data)

@ast_only_test
def test_train(self):

static_res = self.train_bmn(self.args, self.place, to_static=True)
3 changes: 1 addition & 2 deletions test/dygraph_to_static/test_container.py
@@ -17,7 +17,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from dygraph_to_static_util import dy2static_unittest

import paddle

@@ -106,7 +106,6 @@ def _run(self, to_static):

return out

@ast_only_test
def test_train(self):
paddle.jit.set_code_level(100)
dy_out = self._run(to_static=False)
1 change: 0 additions & 1 deletion test/dygraph_to_static/test_fallback.py
@@ -112,7 +112,6 @@ def test_case_training(self):
np.testing.assert_allclose(u_net(self.x).numpy(), [1, 1])
assert u_net.training is False, "Training must be false."

@ast_only_test
def test_case_save_error(self):
"""
test the save will raise error.
3 changes: 0 additions & 3 deletions test/dygraph_to_static/test_for_enumerate.py
@@ -17,7 +17,6 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test

import paddle
from paddle import fluid
@@ -564,13 +563,11 @@ def setUp(self):
def tearDown(self):
self.temp_dir.cleanup()

@ast_only_test
def test_for_zip_error(self):
with self.assertRaises(RuntimeError):
model_path = os.path.join(self.temp_dir.name, 'for_zip_error')
paddle.jit.save(for_zip_error, model_path)

@ast_only_test
def test_for_zip(self):
model_path = os.path.join(self.temp_dir.name, 'for_zip')
paddle.jit.save(for_zip, model_path)
4 changes: 1 addition & 3 deletions test/dygraph_to_static/test_grad.py
@@ -17,7 +17,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from dygraph_to_static_util import dy2static_unittest

import paddle

@@ -101,7 +101,6 @@ def setUp(self):
def tearDown(self):
self.temp_dir.cleanup()

@ast_only_test
def test_save_infer_program(self):
self.setUp() # make self.func change to ast mode
input_spec = [
@@ -114,7 +113,6 @@ def test_save_infer_program(self):
load_res = load_func(self.x).numpy()
np.testing.assert_allclose(origin_res, load_res, rtol=1e-05)

@ast_only_test
def test_save_train_program(self):
self.setUp() # make self.func change to ast mode
grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0)
2 changes: 0 additions & 2 deletions test/dygraph_to_static/test_layer_hook.py
@@ -17,7 +17,6 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test

import paddle

@@ -83,7 +82,6 @@ def load_train(self):
out = net(self.x)
return float(out)

@ast_only_test
def test_hook(self):
dy_out = self.train_net(to_static=False)
st_out = self.train_net(to_static=True)
2 changes: 0 additions & 2 deletions test/dygraph_to_static/test_loop.py
@@ -16,7 +16,6 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test

import paddle
import paddle.nn.functional as F
@@ -462,7 +461,6 @@ def forward(self, x):


class TestForLoopMeetDict(unittest.TestCase):
@ast_only_test
def test_start(self):
net = Net()
model = paddle.jit.to_static(
2 changes: 0 additions & 2 deletions test/dygraph_to_static/test_lstm.py
@@ -144,7 +144,6 @@ def setUp(self):
def tearDown(self):
self.temp_dir.cleanup()

@ast_only_test
def test_save_in_eval(self):
paddle.jit.enable_to_static(True)
net = LinearNet()
@@ -191,7 +190,6 @@ def setUp(self):
def tearDown(self):
self.temp_dir.cleanup()

@ast_only_test
def test_eval_after_save(self):
x = paddle.randn((2, 10, 12)).astype('float32')
net = Net(12, 2)
2 changes: 0 additions & 2 deletions test/dygraph_to_static/test_mobile_net.py
@@ -19,7 +19,6 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test
from predictor_utils import PredictorTools

import paddle
@@ -714,7 +713,6 @@ def assert_same_predict(self, model_name):
),
)

@ast_only_test
def test_mobile_net(self):
# MobileNet-V1
self.assert_same_loss("MobileNetV1")
3 changes: 0 additions & 3 deletions test/dygraph_to_static/test_resnet.py
@@ -414,7 +414,6 @@ def verify_predict(self):
),
)

@ast_only_test
def test_resnet(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
@@ -443,7 +442,6 @@ def test_resnet_composite_backward(self):
),
)

@ast_only_test
def test_resnet_composite_forward_backward(self):
core._set_prim_all_enabled(True)
static_loss = self.train(to_static=True)
@@ -458,7 +456,6 @@ def test_resnet_composite_forward_backward(self):
),
)

@ast_only_test
def test_in_static_mode_mkldnn(self):
fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
4 changes: 0 additions & 4 deletions test/dygraph_to_static/test_resnet_v2.py
@@ -19,7 +19,6 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test
from predictor_utils import PredictorTools

import paddle
@@ -413,7 +412,6 @@ def verify_predict(self):
),
)

@ast_only_test
def test_resnet(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
@@ -427,7 +425,6 @@ def test_resnet(self):
)
self.verify_predict()

@ast_only_test
def test_resnet_composite(self):
core._set_prim_backward_enabled(True)
core._add_skip_comp_ops("batch_norm")
@@ -443,7 +440,6 @@ def test_resnet_composite(self):
),
)

@ast_only_test
def test_in_static_mode_mkldnn(self):
paddle.fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
3 changes: 1 addition & 2 deletions test/dygraph_to_static/test_typing.py
@@ -17,7 +17,7 @@
from typing import Dict, List, Tuple

import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from dygraph_to_static_util import dy2static_unittest

import paddle

@@ -94,7 +94,6 @@ def run_dy(self):
out, _ = self.net(self.x)
return out

@ast_only_test
def test_type(self):
self.net = self.build_net()
out = self.run_dy()
