diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index c3d603dadd..9bdc80195f 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -484,7 +484,7 @@ def collect_single_finetune_params( if i != "_extra_state" and f".{_model_key}." in i ] for item_key in target_keys: - if _new_fitting and ".fitting_net." in item_key: + if _new_fitting and (".descriptor." not in item_key): # print(f'Keep {item_key} in old model!') _new_state_dict[item_key] = ( _random_state_dict[item_key].clone().detach() diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 0833200d47..fa9e5c138a 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -448,5 +448,73 @@ def tearDown(self) -> None: DPTrainTest.tearDown(self) +class TestPropFintuFromEnerModel(unittest.TestCase): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_dpa1) + self.config["model"]["type_map"] = ["H", "C", "N", "O"] + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + + property_input = str(Path(__file__).parent / "property/input.json") + with open(property_input) as f: + self.config_property = json.load(f) + prop_data_file = [str(Path(__file__).parent / "property/single")] + self.config_property["training"]["training_data"]["systems"] = prop_data_file + self.config_property["training"]["validation_data"]["systems"] = prop_data_file + self.config_property["model"]["descriptor"] = deepcopy(model_dpa1["descriptor"]) + self.config_property["training"]["numb_steps"] = 1 + self.config_property["training"]["save_freq"] = 1 + + def test_dp_train(self): + # test training from scratch + trainer = get_trainer(deepcopy(self.config)) + trainer.run() + state_dict_trained = trainer.wrapper.model.state_dict() + + # test fine-tuning using diffferent fitting_net, here using property fitting + finetune_model = self.config["training"].get("save_ckpt", "model.ckpt") + ".pt" + self.config_property["model"], finetune_links = get_finetune_rules( + finetune_model, + self.config_property["model"], + model_branch="RANDOM", + ) + trainer_finetune = get_trainer( + deepcopy(self.config_property), + finetune_model=finetune_model, + finetune_links=finetune_links, + ) + + # check parameters + state_dict_finetuned = trainer_finetune.wrapper.model.state_dict() + for state_key in state_dict_finetuned: + if ( + "out_bias" not in state_key + and "out_std" not in state_key + and "fitting" not in state_key + ): + torch.testing.assert_close( + state_dict_trained[state_key], + state_dict_finetuned[state_key], + ) + + # check running + trainer_finetune.run() + + def tearDown(self): + for f in os.listdir("."): + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out"]: + os.remove(f) + if f in ["stat_files"]: + shutil.rmtree(f) + + if __name__ == "__main__": unittest.main()