From 03e9a3aecceeb4c574b5383802c24dd74fd45384 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 27 Apr 2024 04:37:46 -0400 Subject: [PATCH 1/3] tests: move init_models to setUpModule so that the setup timing will be shown correctly Signed-off-by: Jinzhe Zeng --- source/tests/pt/test_finetune.py | 79 ++++++++++--------- source/tests/pt/test_multitask.py | 9 ++- .../tests/tf/test_model_compression_se_a.py | 12 ++- .../tf/test_model_compression_se_a_ebd.py | 12 ++- ...odel_compression_se_a_ebd_type_one_side.py | 4 +- .../tf/test_model_compression_se_atten.py | 12 ++- .../tests/tf/test_model_compression_se_r.py | 4 +- .../tests/tf/test_model_compression_se_t.py | 4 +- 8 files changed, 86 insertions(+), 50 deletions(-) diff --git a/source/tests/pt/test_finetune.py b/source/tests/pt/test_finetune.py index a874d35497..60a356b538 100644 --- a/source/tests/pt/test_finetune.py +++ b/source/tests/pt/test_finetune.py @@ -33,44 +33,47 @@ model_zbl, ) -energy_data_requirement = [ - DataRequirementItem( - "energy", - ndof=1, - atomic=False, - must=False, - high_prec=True, - ), - DataRequirementItem( - "force", - ndof=3, - atomic=True, - must=False, - high_prec=False, - ), - DataRequirementItem( - "virial", - ndof=9, - atomic=False, - must=False, - high_prec=False, - ), - DataRequirementItem( - "atom_ener", - ndof=1, - atomic=True, - must=False, - high_prec=False, - ), - DataRequirementItem( - "atom_pref", - ndof=1, - atomic=True, - must=False, - high_prec=False, - repeat=3, - ), -] + +def setUpModule(): + global energy_data_requirement + energy_data_requirement = [ + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + must=False, + high_prec=True, + ), + DataRequirementItem( + "force", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "virial", + ndof=9, + atomic=False, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atom_ener", + ndof=1, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atom_pref", + ndof=1, + atomic=True, + must=False, + high_prec=False, + repeat=3, + ), + ] class FinetuneTest: diff --git a/source/tests/pt/test_multitask.py b/source/tests/pt/test_multitask.py index 3c78484e1f..08b632a2e4 100644 --- a/source/tests/pt/test_multitask.py +++ b/source/tests/pt/test_multitask.py @@ -25,9 +25,12 @@ model_se_e2_a, ) -multitask_template_json = str(Path(__file__).parent / "water/multitask.json") -with open(multitask_template_json) as f: - multitask_template = json.load(f) + +def setUpModule(): + global multitask_template + multitask_template_json = str(Path(__file__).parent / "water/multitask.json") + with open(multitask_template_json) as f: + multitask_template = json.load(f) class MultiTaskTrainTest: diff --git a/source/tests/tf/test_model_compression_se_a.py b/source/tests/tf/test_model_compression_se_a.py index 4e49dd44e0..60655074ca 100644 --- a/source/tests/tf/test_model_compression_se_a.py +++ b/source/tests/tf/test_model_compression_se_a.py @@ -73,8 +73,16 @@ def _init_models_exclude_types(): return INPUT, frozen_model, compressed_model -INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() -INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types() +def setUpModule(): + global \ + INPUT, \ + FROZEN_MODEL, \ + COMPRESSED_MODEL, \ + INPUT_ET, \ + FROZEN_MODEL_ET, \ + COMPRESSED_MODEL_ET + INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() + INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types() class TestDeepPotAPBC(unittest.TestCase): diff --git a/source/tests/tf/test_model_compression_se_a_ebd.py b/source/tests/tf/test_model_compression_se_a_ebd.py index debae1f0ba..1864a5196f 100644 --- a/source/tests/tf/test_model_compression_se_a_ebd.py +++ b/source/tests/tf/test_model_compression_se_a_ebd.py @@ -85,8 +85,16 @@ def _init_models_exclude_types(): return INPUT, frozen_model, compressed_model -INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() -INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types() +def setUpModule(): + global \ + INPUT, \ + FROZEN_MODEL, \ + COMPRESSED_MODEL, \ + INPUT_ET, \ + FROZEN_MODEL_ET, \ + COMPRESSED_MODEL_ET + INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() + INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types() class TestDeepPotAPBC(unittest.TestCase): diff --git a/source/tests/tf/test_model_compression_se_a_ebd_type_one_side.py b/source/tests/tf/test_model_compression_se_a_ebd_type_one_side.py index a24bf48398..e0a9913242 100644 --- a/source/tests/tf/test_model_compression_se_a_ebd_type_one_side.py +++ b/source/tests/tf/test_model_compression_se_a_ebd_type_one_side.py @@ -85,7 +85,9 @@ def _init_models_exclude_types(): return INPUT, frozen_model, compressed_model -INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() +def setUpModule(): + global INPUT, FROZEN_MODEL, COMPRESSED_MODEL + INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() class TestDeepPotAPBC(unittest.TestCase): diff --git a/source/tests/tf/test_model_compression_se_atten.py b/source/tests/tf/test_model_compression_se_atten.py index 03ddedad39..5775725adf 100644 --- a/source/tests/tf/test_model_compression_se_atten.py +++ b/source/tests/tf/test_model_compression_se_atten.py @@ -154,8 +154,16 @@ def _init_models_exclude_types(): return inputs, frozen_models, compressed_models -INPUTS, FROZEN_MODELS, COMPRESSED_MODELS = _init_models() -INPUTS_ET, FROZEN_MODELS_ET, COMPRESSED_MODELS_ET = _init_models_exclude_types() +def setUpModule(): + global \ + INPUTS, \ + FROZEN_MODELS, \ + COMPRESSED_MODELS, \ + INPUTS_ET, \ + FROZEN_MODELS_ET, \ + COMPRESSED_MODELS_ET + INPUTS, FROZEN_MODELS, COMPRESSED_MODELS = _init_models() + INPUTS_ET, FROZEN_MODELS_ET, COMPRESSED_MODELS_ET = _init_models_exclude_types() def _get_default_places(nth_test): diff --git a/source/tests/tf/test_model_compression_se_r.py b/source/tests/tf/test_model_compression_se_r.py index 26665e5354..324ee248a4 100644 --- a/source/tests/tf/test_model_compression_se_r.py +++ b/source/tests/tf/test_model_compression_se_r.py @@ -60,7 +60,9 @@ def _init_models(): return INPUT, frozen_model, compressed_model -INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() +def setUpModule(): + global INPUT, FROZEN_MODEL, COMPRESSED_MODEL + INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() class TestDeepPotAPBC(unittest.TestCase): diff --git a/source/tests/tf/test_model_compression_se_t.py b/source/tests/tf/test_model_compression_se_t.py index ec68176cdb..8c23196535 100644 --- a/source/tests/tf/test_model_compression_se_t.py +++ b/source/tests/tf/test_model_compression_se_t.py @@ -60,7 +60,9 @@ def _init_models(): return INPUT, frozen_model, compressed_model -INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() +def setUpModule(): + global INPUT, FROZEN_MODEL, COMPRESSED_MODEL + INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() def tearDownModule(): From d268015832bce3547b2c1482ddef878a6d9215ac Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 27 Apr 2024 04:45:02 -0400 Subject: [PATCH 2/3] revert test_finetune.py Signed-off-by: Jinzhe Zeng --- source/tests/pt/test_finetune.py | 79 +++++++++++++++----------------- 1 file changed, 38 insertions(+), 41 deletions(-) diff --git a/source/tests/pt/test_finetune.py b/source/tests/pt/test_finetune.py index 60a356b538..a874d35497 100644 --- a/source/tests/pt/test_finetune.py +++ b/source/tests/pt/test_finetune.py @@ -33,47 +33,44 @@ model_zbl, ) - -def setUpModule(): - global energy_data_requirement - energy_data_requirement = [ - DataRequirementItem( - "energy", - ndof=1, - atomic=False, - must=False, - high_prec=True, - ), - DataRequirementItem( - "force", - ndof=3, - atomic=True, - must=False, - high_prec=False, - ), - DataRequirementItem( - "virial", - ndof=9, - atomic=False, - must=False, - high_prec=False, - ), - DataRequirementItem( - "atom_ener", - ndof=1, - atomic=True, - must=False, - high_prec=False, - ), - DataRequirementItem( - "atom_pref", - ndof=1, - atomic=True, - must=False, - high_prec=False, - repeat=3, - ), - ] +energy_data_requirement = [ + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + must=False, + high_prec=True, + ), + DataRequirementItem( + "force", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "virial", + ndof=9, + atomic=False, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atom_ener", + ndof=1, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atom_pref", + ndof=1, + atomic=True, + must=False, + high_prec=False, + repeat=3, + ), +] class FinetuneTest: From e4e07ce9079ab9fcc1d09d0d955a37673353ac76 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 6 May 2024 04:19:20 -0400 Subject: [PATCH 3/3] disable_mixed_precision_graph_rewrite Signed-off-by: Jinzhe Zeng --- deepmd/tf/common.py | 9 +++++++++ source/tests/tf/test_mixed_prec_training.py | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/deepmd/tf/common.py b/deepmd/tf/common.py index 5f2d0d882e..06be22a2ee 100644 --- a/deepmd/tf/common.py +++ b/deepmd/tf/common.py @@ -13,6 +13,9 @@ ) import tensorflow +from packaging.version import ( + Version, +) from tensorflow.python.framework import ( tensor_util, ) @@ -31,6 +34,7 @@ ) from deepmd.tf.env import ( GLOBAL_TF_FLOAT_PRECISION, + TF_VERSION, op_module, tf, ) @@ -289,3 +293,8 @@ def clear_session(): tf.reset_default_graph() # TODO: remove this line when data_requirement is not a global variable data_requirement.clear() + _TF_VERSION = Version(TF_VERSION) + if _TF_VERSION < Version("2.4.0"): + tf.train.experimental.disable_mixed_precision_graph_rewrite() + else: + tf.mixed_precision.disable_mixed_precision_graph_rewrite() diff --git a/source/tests/tf/test_mixed_prec_training.py b/source/tests/tf/test_mixed_prec_training.py index 4a4021771d..b43ad7ce2f 100644 --- a/source/tests/tf/test_mixed_prec_training.py +++ b/source/tests/tf/test_mixed_prec_training.py @@ -8,6 +8,9 @@ Version, ) +from deepmd.tf.common import ( + clear_session, +) from deepmd.tf.env import ( TF_VERSION, ) @@ -61,3 +64,4 @@ def tearDown(self): _file_delete("model.ckpt-100.data-00000-of-00001") _file_delete("input_v2_compat.json") _file_delete("lcurve.out") + clear_session()