【PIR/Dy2static】Fix pir test ---- PART II #59532

Merged · 19 commits · Dec 1, 2023
python/paddle/nn/layer/norm.py (1 addition, 1 deletion)
@@ -1078,7 +1078,7 @@ def forward(self, input):
)
else:
act_op = getattr(_C_ops, self._act)
return act_op(input)
return act_op(batch_norm_out)
else:
# create output
# mean and mean_out share the same memory
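The norm.py change is a one-line bug fix: when `self._act` is set, the fused path must apply the activation to the batch-norm result rather than to the raw input. Below is a minimal sketch of the intended dataflow, illustrative only, using `relu` as the activation; it is not the layer's internal code.

    import paddle
    import paddle.nn.functional as F

    # Sketch of what the fixed line means: the activation sees the
    # normalized output, not the layer input.
    x = paddle.randn([2, 4, 8, 8])
    bn = paddle.nn.BatchNorm(4)        # the layer being fixed in norm.py
    batch_norm_out = bn(x)
    out = F.relu(batch_norm_out)       # correct: act(batch_norm_out)
    # The pre-fix code effectively returned act(input), i.e. F.relu(x),
    # skipping the normalization whenever a fused activation was configured.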
test/dygraph_to_static/test_bmn.py (7 additions, 24 deletions)
@@ -21,7 +21,7 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
static_guard,
test_pt_only,
test_legacy_and_pt_and_pir,
)
from predictor_utils import PredictorTools

@@ -625,16 +625,6 @@ def val_bmn(model, args):
float(pem_cls_loss),
]

print(
f'[VALID] iter {batch_id} '
+ '\tLoss = {}, \ttem_loss = {}, \tpem_reg_loss = {}, \tpem_cls_loss = {}'.format(
'%f' % float(avg_loss),
'%f' % float(tem_loss),
'%f' % float(pem_reg_loss),
'%f' % float(pem_cls_loss),
)
)

if batch_id == args.valid_batch_num:
break
return loss_data
@@ -722,17 +712,6 @@ def train_bmn(self, args, to_static):
float(pem_cls_loss),
]

if args.log_interval > 0 and (
batch_id % args.log_interval == 0
):
print(
f'[TRAIN] Epoch {epoch}, iter {batch_id} '
+ f'\tLoss = {float(avg_loss):f}, '
+ f'\ttem_loss = {float(tem_loss):f}, '
+ f'\tpem_reg_loss = {float(pem_reg_loss):f}, '
+ f'\tpem_cls_loss = {float(pem_cls_loss):f}'
)

# validation
if batch_id % args.valid_interval == 0 and batch_id > 0:
bmn.eval()
@@ -741,7 +720,11 @@ def train_bmn(self, args, to_static):
loss_data += val_loss_data

if batch_id == args.train_batch_num:
if to_static:
# TODO(@xiongkun): open after save / load supported in pir.
if (
to_static
and not paddle.base.framework.use_pir_api()
):
paddle.jit.save(bmn, self.model_save_prefix)
else:
paddle.save(
@@ -751,7 +734,7 @@ def train_bmn(self, args, to_static):
break
return np.array(loss_data)

@test_pt_only
@test_legacy_and_pt_and_pir
def test_train_pir(self):
static_res = self.train_bmn(self.args, to_static=True)
dygraph_res = self.train_bmn(self.args, to_static=False)
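The new guard in train_bmn skips `paddle.jit.save` under PIR because jit save/load is not yet supported there (per the TODO in the diff). A standalone sketch of the same branching follows; the helper name and path arguments are hypothetical, the condition and calls are the ones from the diff.

    import paddle

    def save_trained_model(model, model_save_prefix, dy_param_path, to_static):
        # Mirrors the branch added in test_bmn.py: use jit.save only on the
        # legacy to-static path for now (see the TODO above); otherwise fall
        # back to saving the raw state dict.
        if to_static and not paddle.base.framework.use_pir_api():
            paddle.jit.save(model, model_save_prefix)
        else:
            paddle.save(model.state_dict(), dy_param_path)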
test/dygraph_to_static/test_to_tensor.py (17 additions, 18 deletions)
@@ -21,8 +21,6 @@
ToStaticMode,
disable_test_case,
test_legacy_and_pt_and_pir,
test_legacy_only,
test_pir_only,
)

import paddle
@@ -184,6 +182,7 @@ def test_to_tensor_err_log(self):


class TestStatic(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_static(self):
paddle.enable_static()
main_prog = paddle.static.Program()
Expand All @@ -194,14 +193,16 @@ def test_static(self):
else:
place = paddle.CPUPlace()

paddle.set_default_dtype("float64")
x = paddle.to_tensor(
paddle.randn([5, 2]),
dtype='float64',
stop_gradient=False,
place=place,
)

out = paddle.static.nn.fc(x, 1)
fc_net = paddle.nn.Linear(2, 1)
out = fc_net(x)

sgd = paddle.optimizer.SGD()
sgd.minimize(paddle.mean(out))
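In test_static, `paddle.set_default_dtype("float64")` is added, presumably so that the parameters of the new `Linear` layer are created as float64 and match the float64 input, and the legacy-only `paddle.static.nn.fc` is replaced with `paddle.nn.Linear`, which builds under both the old static graph and PIR. A condensed, standalone sketch of the same pattern (assumed setup, not the test verbatim):

    import paddle

    paddle.enable_static()
    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        # Parameters created below pick up the default dtype, so it has to
        # match the float64 tensor fed into the layer.
        paddle.set_default_dtype("float64")
        x = paddle.to_tensor(
            paddle.randn([5, 2]), dtype='float64', stop_gradient=False
        )
        fc_net = paddle.nn.Linear(2, 1)  # works under legacy static graph and PIR
        out = fc_net(x)
        loss = paddle.mean(out)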
@@ -212,29 +213,27 @@


class TestInt16(Dy2StTestBase):
@test_legacy_only
@test_legacy_and_pt_and_pir
def test_static(self):
import numpy as np

paddle.enable_static()
data = np.array([1, 2], dtype="int16")
x = paddle.to_tensor(data)
self.assertTrue(x.dtype == paddle.framework.core.VarDesc.VarType.INT16)

y = paddle.to_tensor([1, 2], dtype="int16")
self.assertTrue(y.dtype == paddle.framework.core.VarDesc.VarType.INT16)

@test_pir_only
def test_static_pir(self):
import numpy as np

paddle.enable_static()
data = np.array([1, 2], dtype="int16")
x = paddle.to_tensor(data)
self.assertTrue(x.dtype == paddle.base.libpaddle.DataType.INT16)
if paddle.base.framework.use_pir_api():
self.assertTrue(x.dtype == paddle.base.libpaddle.DataType.INT16)
else:
self.assertTrue(
x.dtype == paddle.framework.core.VarDesc.VarType.INT16
)

y = paddle.to_tensor([1, 2], dtype="int16")
self.assertTrue(y.dtype == paddle.base.libpaddle.DataType.INT16)
if paddle.base.framework.use_pir_api():
self.assertTrue(y.dtype == paddle.base.libpaddle.DataType.INT16)
else:
self.assertTrue(
y.dtype == paddle.framework.core.VarDesc.VarType.INT16
)


if __name__ == '__main__':
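The TestInt16 changes fold the legacy-only and PIR-only variants into a single test by branching on how each mode reports dtypes. A hedged sketch of the same check as a small helper; the helper name is hypothetical, the dtype identifiers and the `use_pir_api()` check are the ones used in the diff.

    import paddle

    def expected_int16_dtype():
        # Under the PIR API, tensors report dtypes via
        # paddle.base.libpaddle.DataType; the legacy program still uses
        # VarDesc.VarType.
        if paddle.base.framework.use_pir_api():
            return paddle.base.libpaddle.DataType.INT16
        return paddle.framework.core.VarDesc.VarType.INT16

    # Usage inside the merged test would then reduce to:
    #     x = paddle.to_tensor(np.array([1, 2], dtype="int16"))
    #     assert x.dtype == expected_int16_dtype()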