Skip to content

Commit

Permalink
【PIR/Dy2static】Fix pir test ---- PART II (#59532)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: chenzhiyang <[email protected]>
  • Loading branch information
2742195759 and changeyoung98 authored Dec 1, 2023
1 parent eeb35ea commit 2bf7567
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 43 deletions.
2 changes: 1 addition & 1 deletion python/paddle/nn/layer/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,7 @@ def forward(self, input):
)
else:
act_op = getattr(_C_ops, self._act)
return act_op(input)
return act_op(batch_norm_out)
else:
# create output
# mean and mean_out share the same memory
Expand Down
31 changes: 7 additions & 24 deletions test/dygraph_to_static/test_bmn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
static_guard,
test_pt_only,
test_legacy_and_pt_and_pir,
)
from predictor_utils import PredictorTools

Expand Down Expand Up @@ -625,16 +625,6 @@ def val_bmn(model, args):
float(pem_cls_loss),
]

print(
f'[VALID] iter {batch_id} '
+ '\tLoss = {}, \ttem_loss = {}, \tpem_reg_loss = {}, \tpem_cls_loss = {}'.format(
'%f' % float(avg_loss),
'%f' % float(tem_loss),
'%f' % float(pem_reg_loss),
'%f' % float(pem_cls_loss),
)
)

if batch_id == args.valid_batch_num:
break
return loss_data
Expand Down Expand Up @@ -722,17 +712,6 @@ def train_bmn(self, args, to_static):
float(pem_cls_loss),
]

if args.log_interval > 0 and (
batch_id % args.log_interval == 0
):
print(
f'[TRAIN] Epoch {epoch}, iter {batch_id} '
+ f'\tLoss = {float(avg_loss):f}, '
+ f'\ttem_loss = {float(tem_loss):f}, '
+ f'\tpem_reg_loss = {float(pem_reg_loss):f}, '
+ f'\tpem_cls_loss = {float(pem_cls_loss):f}'
)

# validation
if batch_id % args.valid_interval == 0 and batch_id > 0:
bmn.eval()
Expand All @@ -741,7 +720,11 @@ def train_bmn(self, args, to_static):
loss_data += val_loss_data

if batch_id == args.train_batch_num:
if to_static:
# TODO(@xiongkun): open after save / load supported in pir.
if (
to_static
and not paddle.base.framework.use_pir_api()
):
paddle.jit.save(bmn, self.model_save_prefix)
else:
paddle.save(
Expand All @@ -751,7 +734,7 @@ def train_bmn(self, args, to_static):
break
return np.array(loss_data)

@test_pt_only
@test_legacy_and_pt_and_pir
def test_train_pir(self):
static_res = self.train_bmn(self.args, to_static=True)
dygraph_res = self.train_bmn(self.args, to_static=False)
Expand Down
35 changes: 17 additions & 18 deletions test/dygraph_to_static/test_to_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
ToStaticMode,
disable_test_case,
test_legacy_and_pt_and_pir,
test_legacy_only,
test_pir_only,
)

import paddle
Expand Down Expand Up @@ -184,6 +182,7 @@ def test_to_tensor_err_log(self):


class TestStatic(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_static(self):
paddle.enable_static()
main_prog = paddle.static.Program()
Expand All @@ -194,14 +193,16 @@ def test_static(self):
else:
place = paddle.CPUPlace()

paddle.set_default_dtype("float64")
x = paddle.to_tensor(
paddle.randn([5, 2]),
dtype='float64',
stop_gradient=False,
place=place,
)

out = paddle.static.nn.fc(x, 1)
fc_net = paddle.nn.Linear(2, 1)
out = fc_net(x)

sgd = paddle.optimizer.SGD()
sgd.minimize(paddle.mean(out))
Expand All @@ -212,29 +213,27 @@ def test_static(self):


class TestInt16(Dy2StTestBase):
@test_legacy_only
@test_legacy_and_pt_and_pir
def test_static(self):
import numpy as np

paddle.enable_static()
data = np.array([1, 2], dtype="int16")
x = paddle.to_tensor(data)
self.assertTrue(x.dtype == paddle.framework.core.VarDesc.VarType.INT16)

y = paddle.to_tensor([1, 2], dtype="int16")
self.assertTrue(y.dtype == paddle.framework.core.VarDesc.VarType.INT16)

@test_pir_only
def test_static_pir(self):
import numpy as np

paddle.enable_static()
data = np.array([1, 2], dtype="int16")
x = paddle.to_tensor(data)
self.assertTrue(x.dtype == paddle.base.libpaddle.DataType.INT16)
if paddle.base.framework.use_pir_api():
self.assertTrue(x.dtype == paddle.base.libpaddle.DataType.INT16)
else:
self.assertTrue(
x.dtype == paddle.framework.core.VarDesc.VarType.INT16
)

y = paddle.to_tensor([1, 2], dtype="int16")
self.assertTrue(y.dtype == paddle.base.libpaddle.DataType.INT16)
if paddle.base.framework.use_pir_api():
self.assertTrue(y.dtype == paddle.base.libpaddle.DataType.INT16)
else:
self.assertTrue(
y.dtype == paddle.framework.core.VarDesc.VarType.INT16
)


if __name__ == '__main__':
Expand Down

0 comments on commit 2bf7567

Please sign in to comment.