[Dy2St][NO.2] pir dy2st unittest fix test_bert - Part 1 (PaddlePadd…
gouzil authored and HermitSun committed Dec 21, 2023
1 parent 3383279 commit e98fe90
Showing 1 changed file with 78 additions and 68 deletions.
146 changes: 78 additions & 68 deletions test/dygraph_to_static/test_bert.py
@@ -23,16 +23,21 @@
 from dygraph_to_static_utils import (
     Dy2StTestBase,
     enable_to_static_guard,
+    test_legacy_and_pt_and_pir,
     test_sot_only,
 )
 from predictor_utils import PredictorTools
 
 import paddle
 from paddle import base
 from paddle.base import core
+from paddle.base.framework import unique_name
+from paddle.framework import use_pir_api
 from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
 
-place = base.CUDAPlace(0) if base.is_compiled_with_cuda() else base.CPUPlace()
+place = (
+    paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
+)
 SEED = 2020
 STEP_NUM = 10
 PRINT_STEP = 2
@@ -95,7 +100,7 @@ def tearDown(self):
         self.temp_dir.cleanup()
 
     def train(self, bert_config, data_reader, to_static):
-        with base.dygraph.guard(place):
+        with unique_name.guard():
             base.default_main_program().random_seed = SEED
             base.default_startup_program().random_seed = SEED
 
@@ -158,7 +163,9 @@ def train(self, bert_config, data_reader, to_static):
                 step_idx += 1
                 if step_idx == STEP_NUM:
                     if to_static:
-                        paddle.jit.save(bert, self.model_save_prefix)
+                        # TODO(pir-save-load): Fix this after we support save/load in PIR
+                        if not use_pir_api():
+                            paddle.jit.save(bert, self.model_save_prefix)
                     else:
                         paddle.save(
                             bert.state_dict(),
@@ -172,8 +179,7 @@ def train_dygraph(self, bert_config, data_reader):
         return self.train(bert_config, data_reader, False)
 
     def train_static(self, bert_config, data_reader):
-        with enable_to_static_guard(True):
-            return self.train(bert_config, data_reader, True)
+        return self.train(bert_config, data_reader, True)
 
     def predict_static(self, data):
         paddle.enable_static()
@@ -195,11 +201,12 @@ def predict_static(self, data):
             fetch_list=fetch_targets,
         )
 
+        paddle.disable_static()
         return pred_res
 
     def predict_dygraph(self, bert_config, data):
         with enable_to_static_guard(False):
-            with base.dygraph.guard(place):
+            with unique_name.guard():
                 bert = PretrainModelLayer(
                     config=bert_config, weight_sharing=False, use_fp16=False
                 )
@@ -210,7 +217,7 @@ def predict_dygraph(self, bert_config, data):
                 bert.set_dict(model_dict)
                 bert.eval()
 
-                input_vars = [base.dygraph.to_variable(x) for x in data]
+                input_vars = [paddle.to_tensor(x) for x in data]
                 (
                     src_ids,
                     pos_ids,
@@ -234,31 +241,30 @@ def predict_dygraph_jit(self, data):
                 return pred_res
 
     def predict_dygraph_jit(self, data):
-        with base.dygraph.guard(place):
-            bert = paddle.jit.load(self.model_save_prefix)
-            bert.eval()
-
-            (
-                src_ids,
-                pos_ids,
-                sent_ids,
-                input_mask,
-                mask_label,
-                mask_pos,
-                labels,
-            ) = data
-            pred_res = bert(
-                src_ids,
-                pos_ids,
-                sent_ids,
-                input_mask,
-                mask_label,
-                mask_pos,
-                labels,
-            )
-            pred_res = [var.numpy() for var in pred_res]
+        bert = paddle.jit.load(self.model_save_prefix)
+        bert.eval()
+
+        (
+            src_ids,
+            pos_ids,
+            sent_ids,
+            input_mask,
+            mask_label,
+            mask_pos,
+            labels,
+        ) = data
+        pred_res = bert(
+            src_ids,
+            pos_ids,
+            sent_ids,
+            input_mask,
+            mask_label,
+            mask_pos,
+            labels,
+        )
+        pred_res = [var.numpy() for var in pred_res]
 
-            return pred_res
+        return pred_res
 
     def predict_analysis_inference(self, data):
         output = PredictorTools(
@@ -267,6 +273,7 @@
         out = output()
         return out
 
+    @test_legacy_and_pt_and_pir
     def test_train(self):
         static_loss, static_ppl = self.train_static(
             self.bert_config, self.data_reader
@@ -280,6 +287,7 @@ def test_train(self):
         self.verify_predict()
 
     @test_sot_only
+    @test_legacy_and_pt_and_pir
     def test_train_composite(self):
         core._set_prim_backward_enabled(True)
         # core._add_skip_comp_ops("layer_norm")
@@ -297,43 +305,45 @@ def test_train_composite(self):
     def verify_predict(self):
         for data in self.data_reader.data_generator()():
            dygraph_pred_res = self.predict_dygraph(self.bert_config, data)
-            static_pred_res = self.predict_static(data)
-            dygraph_jit_pred_res = self.predict_dygraph_jit(data)
-            predictor_pred_res = self.predict_analysis_inference(data)
-
-            for dy_res, st_res, dy_jit_res, predictor_res in zip(
-                dygraph_pred_res,
-                static_pred_res,
-                dygraph_jit_pred_res,
-                predictor_pred_res,
-            ):
-                np.testing.assert_allclose(
-                    st_res,
-                    dy_res,
-                    rtol=1e-05,
-                    err_msg='dygraph_res: {},\n static_res: {}'.format(
-                        dy_res[~np.isclose(st_res, dy_res)],
-                        st_res[~np.isclose(st_res, dy_res)],
-                    ),
-                )
-                np.testing.assert_allclose(
-                    st_res,
-                    dy_jit_res,
-                    rtol=1e-05,
-                    err_msg='dygraph_jit_res: {},\n static_res: {}'.format(
-                        dy_jit_res[~np.isclose(st_res, dy_jit_res)],
-                        st_res[~np.isclose(st_res, dy_jit_res)],
-                    ),
-                )
-                np.testing.assert_allclose(
-                    st_res,
-                    predictor_res,
-                    rtol=1e-05,
-                    err_msg='dygraph_jit_res: {},\n static_res: {}'.format(
-                        predictor_res[~np.isclose(st_res, predictor_res)],
-                        st_res[~np.isclose(st_res, predictor_res)],
-                    ),
-                )
+            # TODO(pir-save-load): Fix this after we support save/load in PIR
+            if not use_pir_api():
+                static_pred_res = self.predict_static(data)
+                dygraph_jit_pred_res = self.predict_dygraph_jit(data)
+                predictor_pred_res = self.predict_analysis_inference(data)
+
+                for dy_res, st_res, dy_jit_res, predictor_res in zip(
+                    dygraph_pred_res,
+                    static_pred_res,
+                    dygraph_jit_pred_res,
+                    predictor_pred_res,
+                ):
+                    np.testing.assert_allclose(
+                        st_res,
+                        dy_res,
+                        rtol=1e-05,
+                        err_msg='dygraph_res: {},\n static_res: {}'.format(
+                            dy_res[~np.isclose(st_res, dy_res)],
+                            st_res[~np.isclose(st_res, dy_res)],
+                        ),
+                    )
+                    np.testing.assert_allclose(
+                        st_res,
+                        dy_jit_res,
+                        rtol=1e-05,
+                        err_msg='dygraph_jit_res: {},\n static_res: {}'.format(
+                            dy_jit_res[~np.isclose(st_res, dy_jit_res)],
+                            st_res[~np.isclose(st_res, dy_jit_res)],
+                        ),
+                    )
+                    np.testing.assert_allclose(
+                        st_res,
+                        predictor_res,
+                        rtol=1e-05,
+                        err_msg='dygraph_jit_res: {},\n static_res: {}'.format(
+                            predictor_res[~np.isclose(st_res, predictor_res)],
+                            st_res[~np.isclose(st_res, predictor_res)],
+                        ),
+                    )
             break


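The change applies one pattern throughout: skip paddle.jit.save / paddle.jit.load and the static-inference checks whenever the test runs under PIR, because PIR does not support jit save/load yet (see the TODO(pir-save-load) markers). A minimal standalone sketch of that guard, outside the test harness; the Net layer and the save prefix below are hypothetical and only the guard itself mirrors the commit:

import paddle
from paddle.framework import use_pir_api  # same helper the test imports


class Net(paddle.nn.Layer):
    # Hypothetical toy layer; stands in for PretrainModelLayer.
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


net = paddle.jit.to_static(Net())
net(paddle.rand([1, 4]))  # run once so dy2st traces a concrete program

# TODO(pir-save-load): drop the guard once PIR supports jit save/load
if not use_pir_api():
    paddle.jit.save(net, "/tmp/bert_demo/net")  # hypothetical save prefix
    loaded = paddle.jit.load("/tmp/bert_demo/net")
    loaded.eval()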
