Skip to content

Commit

Permalink
[Dy2St] Refactor dy2st unittest decorators name - Part 1 (PaddlePaddl…
Browse files Browse the repository at this point in the history
  • Loading branch information
gouzil authored Oct 27, 2023
1 parent b068738 commit 62cc5b2
Show file tree
Hide file tree
Showing 19 changed files with 123 additions and 184 deletions.
31 changes: 12 additions & 19 deletions test/dygraph_to_static/test_basic_api_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import (
dy2static_unittest,
test_and_compare_with_new_ir,
)
from dygraph_to_static_utils_new import Dy2StTestBase, compare_legacy_with_pir

import paddle
from paddle import base, to_tensor
Expand Down Expand Up @@ -72,8 +69,7 @@ def dyfunc_bool_to_tensor(x):
return paddle.to_tensor(True)


@dy2static_unittest
class TestDygraphBasicApi_ToVariable(unittest.TestCase):
class TestDygraphBasicApi_ToVariable(Dy2StTestBase):
def setUp(self):
self.input = np.ones(5).astype("int32")
self.test_funcs = [
Expand All @@ -96,7 +92,7 @@ def get_dygraph_output(self):
res = self.dygraph_func(self.input).numpy()
return res

@test_and_compare_with_new_ir(True)
@compare_legacy_with_pir
def get_static_output(self):
main_program = base.Program()
main_program.random_seed = SEED
Expand Down Expand Up @@ -234,8 +230,7 @@ def dyfunc_Prelu(input):
return res


@dy2static_unittest
class TestDygraphBasicApi(unittest.TestCase):
class TestDygraphBasicApi(Dy2StTestBase):
# Compare results of dynamic graph and transformed static graph function which only
# includes basic Api.

Expand All @@ -252,7 +247,7 @@ def get_dygraph_output(self):

return res

@test_and_compare_with_new_ir(True)
@compare_legacy_with_pir
def get_static_output(self):
startup_program = base.Program()
startup_program.random_seed = SEED
Expand Down Expand Up @@ -286,7 +281,7 @@ def get_dygraph_output(self):
res = self.dygraph_func(self.input1, self.input2).numpy()
return res

@test_and_compare_with_new_ir(True)
@compare_legacy_with_pir
def get_static_output(self):
startup_program = base.Program()
startup_program.random_seed = SEED
Expand Down Expand Up @@ -401,8 +396,7 @@ def dyfunc_PolynomialDecay():
return paddle.to_tensor(lr)


@dy2static_unittest
class TestDygraphBasicApi_CosineDecay(unittest.TestCase):
class TestDygraphBasicApi_CosineDecay(Dy2StTestBase):
def setUp(self):
self.dygraph_func = dyfunc_CosineDecay

Expand All @@ -413,7 +407,7 @@ def get_dygraph_output(self):
res = self.dygraph_func().numpy()
return res

@test_and_compare_with_new_ir(True)
@compare_legacy_with_pir
def get_static_output(self):
startup_program = base.Program()
startup_program.random_seed = SEED
Expand Down Expand Up @@ -444,7 +438,7 @@ def get_dygraph_output(self):
res = self.dygraph_func()
return res

@test_and_compare_with_new_ir(True)
@compare_legacy_with_pir
def get_static_output(self):
startup_program = base.Program()
startup_program.random_seed = SEED
Expand All @@ -471,7 +465,7 @@ def get_dygraph_output(self):
res = self.dygraph_func()
return res

@test_and_compare_with_new_ir(True)
@compare_legacy_with_pir
def get_static_output(self):
startup_program = base.Program()
startup_program.random_seed = SEED
Expand All @@ -498,7 +492,7 @@ def get_dygraph_output(self):
res = self.dygraph_func()
return res

@test_and_compare_with_new_ir(True)
@compare_legacy_with_pir
def get_static_output(self):
startup_program = base.Program()
startup_program.random_seed = SEED
Expand Down Expand Up @@ -545,8 +539,7 @@ def _dygraph_fn():
np.random.random(1)


@dy2static_unittest
class TestDygraphApiRecognition(unittest.TestCase):
class TestDygraphApiRecognition(Dy2StTestBase):
def setUp(self):
self.src = inspect.getsource(_dygraph_fn)
self.root = gast.parse(self.src)
Expand Down
15 changes: 7 additions & 8 deletions test/dygraph_to_static/test_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
import numpy as np
from bert_dygraph_model import PretrainModelLayer
from bert_utils import get_bert_config, get_feed_data_reader
from dygraph_to_static_util import (
ast_only_test,
dy2static_unittest,
test_with_new_ir,
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_ast_only,
test_pir_only,
)
from predictor_utils import PredictorTools

Expand Down Expand Up @@ -78,8 +78,7 @@ def __len__(self):
return len(self.src_ids)


@dy2static_unittest
class TestBert(unittest.TestCase):
class TestBert(Dy2StTestBase):
def setUp(self):
self.bert_config = get_bert_config()
self.data_reader = get_feed_data_reader(self.bert_config)
Expand Down Expand Up @@ -266,7 +265,7 @@ def predict_analysis_inference(self, data):
out = output()
return out

@test_with_new_ir
@test_pir_only
def test_train_new_ir(self):
static_loss, static_ppl = self.train_static(
self.bert_config, self.data_reader
Expand All @@ -277,7 +276,7 @@ def test_train_new_ir(self):
np.testing.assert_allclose(static_loss, dygraph_loss, rtol=1e-05)
np.testing.assert_allclose(static_ppl, dygraph_ppl, rtol=1e-05)

@ast_only_test
@test_ast_only
def test_train(self):
static_loss, static_ppl = self.train_static(
self.bert_config, self.data_reader
Expand Down
7 changes: 3 additions & 4 deletions test/dygraph_to_static/test_bmn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import dy2static_unittest, test_with_new_ir
from dygraph_to_static_utils_new import Dy2StTestBase, test_pir_only
from predictor_utils import PredictorTools

import paddle
Expand Down Expand Up @@ -637,8 +637,7 @@ def val_bmn(model, args):
return loss_data


@dy2static_unittest
class TestTrain(unittest.TestCase):
class TestTrain(Dy2StTestBase):
def setUp(self):
self.args = Args()
self.place = (
Expand Down Expand Up @@ -751,7 +750,7 @@ def train_bmn(self, args, place, to_static):
break
return np.array(loss_data)

@test_with_new_ir
@test_pir_only
def test_train_new_ir(self):
static_res = self.train_bmn(self.args, self.place, to_static=True)
dygraph_res = self.train_bmn(self.args, self.place, to_static=False)
Expand Down
10 changes: 4 additions & 6 deletions test/dygraph_to_static/test_break_continue.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only

import paddle
from paddle import base
Expand All @@ -26,14 +26,13 @@
np.random.seed(SEED)


@dy2static_unittest
class TestDy2staticException(unittest.TestCase):
class TestDy2staticException(Dy2StTestBase):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = None
self.error = "Your if/else have different number of return value."

@ast_only_test
@test_ast_only
def test_error(self):
if self.dyfunc:
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
Expand Down Expand Up @@ -205,8 +204,7 @@ def test_optim_break_in_while(x):
return x


@dy2static_unittest
class TestContinueInFor(unittest.TestCase):
class TestContinueInFor(Dy2StTestBase):
def setUp(self):
self.input = np.zeros(1).astype('int64')
self.place = (
Expand Down
12 changes: 5 additions & 7 deletions test/dygraph_to_static/test_build_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from dygraph_to_static_utils_new import Dy2StTestBase, test_ast_only
from test_resnet import ResNetHelper

import paddle


@dy2static_unittest
class TestResnetWithPass(unittest.TestCase):
class TestResnetWithPass(Dy2StTestBase):
def setUp(self):
self.build_strategy = paddle.static.BuildStrategy()
self.build_strategy.fuse_elewise_add_act_ops = True
Expand Down Expand Up @@ -62,7 +61,7 @@ def verify_predict(self):
err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
)

@ast_only_test
@test_ast_only
def test_resnet(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
Expand All @@ -74,7 +73,7 @@ def test_resnet(self):
)
self.verify_predict()

@ast_only_test
@test_ast_only
def test_in_static_mode_mkldnn(self):
paddle.base.set_flags({'FLAGS_use_mkldnn': True})
try:
Expand All @@ -84,8 +83,7 @@ def test_in_static_mode_mkldnn(self):
paddle.base.set_flags({'FLAGS_use_mkldnn': False})


@dy2static_unittest
class TestError(unittest.TestCase):
class TestError(Dy2StTestBase):
def test_type_error(self):
def foo(x):
out = x + 1
Expand Down
14 changes: 5 additions & 9 deletions test/dygraph_to_static/test_cache_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from collections import Counter

import numpy as np
from dygraph_to_static_util import dy2static_unittest
from dygraph_to_static_utils_new import Dy2StTestBase
from test_fetch_feed import Linear, Pool2D

import paddle
Expand All @@ -25,8 +25,7 @@
from paddle.jit.dy2static import convert_to_static


@dy2static_unittest
class TestCacheProgram(unittest.TestCase):
class TestCacheProgram(Dy2StTestBase):
def setUp(self):
self.batch_num = 5
self.dygraph_class = Pool2D
Expand Down Expand Up @@ -76,8 +75,7 @@ def setUp(self):
self.data = np.random.random((4, 10)).astype('float32')


@dy2static_unittest
class TestCacheProgramWithOptimizer(unittest.TestCase):
class TestCacheProgramWithOptimizer(Dy2StTestBase):
def setUp(self):
self.dygraph_class = Linear
self.data = np.random.random((4, 10)).astype('float32')
Expand Down Expand Up @@ -126,8 +124,7 @@ def simple_func(x):
return mean


@dy2static_unittest
class TestConvertWithCache(unittest.TestCase):
class TestConvertWithCache(Dy2StTestBase):
def test_cache(self):
static_func = convert_to_static(simple_func)
# Get transformed function from cache.
Expand Down Expand Up @@ -157,8 +154,7 @@ def sum_under_while(limit):
return ret_sum


@dy2static_unittest
class TestToOutputWithCache(unittest.TestCase):
class TestToOutputWithCache(Dy2StTestBase):
def test_output(self):
with base.dygraph.guard():
ret = sum_even_until_limit(80, 10)
Expand Down
10 changes: 3 additions & 7 deletions test/dygraph_to_static/test_cinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import (
dy2static_unittest,
test_and_compare_with_new_ir,
)
from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir

import paddle

Expand All @@ -45,8 +42,7 @@ def apply_to_static(net, use_cinn):
return paddle.jit.to_static(net, build_strategy=build_strategy)


@dy2static_unittest
class TestCINN(unittest.TestCase):
class TestCINN(Dy2StTestBase):
def setUp(self):
self.x = paddle.randn([2, 4])
self.x.stop_gradient = False
Expand Down Expand Up @@ -83,7 +79,7 @@ def train(self, use_cinn):

return res

@test_and_compare_with_new_ir(False)
@test_legacy_and_pir
def test_cinn(self):
dy_res = self.train(use_cinn=False)
cinn_res = self.train(use_cinn=True)
Expand Down
Loading

0 comments on commit 62cc5b2

Please sign in to comment.