diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index 8e8aab939e..f7e2f4c9f6 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -229,6 +229,10 @@ def train(FLAGS):
     shared_links = None
     if multi_task:
         config["model"], shared_links = preprocess_shared_params(config["model"])
+        # handle the special key
+        assert (
+            "RANDOM" not in config["model"]["model_dict"]
+        ), "Model name can not be 'RANDOM' in multi-task mode!"
 
     # update fine-tuning config
     finetune_links = None
@@ -337,7 +341,11 @@ def show(FLAGS):
                 " The provided model does not meet this criterion."
             )
         model_branches = list(model_params["model_dict"].keys())
-        log.info(f"Available model branches are {model_branches}")
+        model_branches += ["RANDOM"]
+        log.info(
+            f"Available model branches are {model_branches}, "
+            f"where 'RANDOM' means using a randomly initialized fitting net."
+        )
     if "type-map" in FLAGS.ATTRIBUTES:
         if model_is_multi_task:
             model_branches = list(model_params["model_dict"].keys())
diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py
index 74f01fc2ea..2dd2230b54 100644
--- a/deepmd/pt/utils/finetune.py
+++ b/deepmd/pt/utils/finetune.py
@@ -30,9 +30,12 @@ def get_finetune_rule_single(
 
     if not from_multitask:
         single_config_chosen = deepcopy(_model_param_pretrained)
+        if model_branch_from == "RANDOM":
+            # not ["", "RANDOM"], because single-from-single fine-tuning uses the pretrained fitting net by default
+            new_fitting = True
     else:
         model_dict_params = _model_param_pretrained["model_dict"]
-        if model_branch_from == "":
+        if model_branch_from in ["", "RANDOM"]:
             model_branch_chosen = next(iter(model_dict_params.keys()))
             new_fitting = True
             log.warning(
@@ -164,21 +167,27 @@ def get_finetune_rules(
     pretrained_keys = last_model_params["model_dict"].keys()
     for model_key in target_keys:
         resuming = False
-        if "finetune_head" in model_config["model_dict"][model_key]:
+        if (
+            "finetune_head" in model_config["model_dict"][model_key]
+            and model_config["model_dict"][model_key]["finetune_head"] != "RANDOM"
+        ):
             pretrained_key = model_config["model_dict"][model_key]["finetune_head"]
             assert pretrained_key in pretrained_keys, (
                 f"'{pretrained_key}' head chosen to finetune not exist in the pretrained model!"
                 f"Available heads are: {list(pretrained_keys)}"
             )
             model_branch_from = pretrained_key
-        elif model_key in pretrained_keys:
+        elif (
+            "finetune_head" not in model_config["model_dict"][model_key]
+            and model_key in pretrained_keys
+        ):
             # not do anything if not defined "finetune_head" in heads that exist in the pretrained model
             # this will just do resuming
             model_branch_from = model_key
             resuming = True
         else:
-            # if not defined "finetune_head" in new heads, the fitting net will bre randomly initialized
-            model_branch_from = ""
+            # if "finetune_head" is not defined in new heads, or is set to "RANDOM", the fitting net will be randomly initialized
+            model_branch_from = "RANDOM"
         model_config["model_dict"][model_key], finetune_rule = (
             get_finetune_rule_single(
                 model_config["model_dict"][model_key],
diff --git a/doc/train/finetuning.md b/doc/train/finetuning.md
index 1cd88191d2..4fbe95b2fd 100644
--- a/doc/train/finetuning.md
+++ b/doc/train/finetuning.md
@@ -68,9 +68,12 @@ The command for this operation is:
 $ dp --pt train input.json --finetune pretrained.pt
 ```
 
-:::{note}
-We do not support fine-tuning from a randomly initialized fitting net in this case, which is the same as implementations in TensorFlow.
-:::
+In this case, it is important to note that the fitting net weights, except the energy bias, will be automatically set to those in the pre-trained model. This default setting is consistent with the implementations in TensorFlow.
+If you wish to fine-tune with a randomly initialized fitting net in this scenario, you can explicitly set the `--model-branch` parameter to "RANDOM":
+
+```bash
+$ dp --pt train input.json --finetune pretrained.pt --model-branch RANDOM
+```
 
 The model section in input.json **must be the same as that in the pretrained model**. If you do not know the model params in the pretrained model, you can add `--use-pretrain-script` in the fine-tuning command:
 
@@ -91,9 +94,9 @@ The model section will be overwritten (except the `type_map` subsection) by that
 
 #### Fine-tuning from a multi-task pre-trained model
 
-Additionally, within the PyTorch implementation and leveraging the flexibility offered by the framework and the multi-task training capabilities provided by DPA2,
+Additionally, within the PyTorch implementation and leveraging the flexibility offered by the framework and the multi-task training process proposed in the DPA2 [paper](https://arxiv.org/abs/2312.15492),
 we also support more general multitask pre-trained models, which includes multiple datasets for pre-training. These pre-training datasets share a common descriptor while maintaining their individual fitting nets,
-as detailed in the DPA2 [paper](https://arxiv.org/abs/2312.15492).
+as detailed in the paper above.
 For fine-tuning using this multitask pre-trained model (`multitask_pretrained.pt`),
 one can select a specific branch (e.g., `CHOOSEN_BRANCH`) included in `multitask_pretrained.pt` for fine-tuning with the following command:
 
@@ -112,7 +115,7 @@ $ dp --pt show multitask_pretrained.pt model-branch
 :::
 
 This command will start fine-tuning based on the pre-trained model's descriptor and the selected branch's fitting net.
-If --model-branch is not set, a randomly initialized fitting net will be used.
+If --model-branch is not set or set to "RANDOM", a randomly initialized fitting net will be used.
 
 ### Multi-task fine-tuning
 
diff --git a/source/tests/pt/test_dp_show.py b/source/tests/pt/test_dp_show.py
index da5137d7ae..5d6cb9bd36 100644
--- a/source/tests/pt/test_dp_show.py
+++ b/source/tests/pt/test_dp_show.py
@@ -149,7 +149,11 @@ def test_checkpoint(self):
             run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
         results = f.getvalue().split("\n")[:-1]
         assert "This is a multitask model" in results[-8]
-        assert "Available model branches are ['model_1', 'model_2']" in results[-7]
+        assert (
+            "Available model branches are ['model_1', 'model_2', 'RANDOM'], "
+            "where 'RANDOM' means using a randomly initialized fitting net."
+            in results[-7]
+        )
         assert "The type_map of branch model_1 is ['O', 'H', 'B']" in results[-6]
         assert "The type_map of branch model_2 is ['O', 'H', 'B']" in results[-5]
         assert (
diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py
index 2926465775..0833200d47 100644
--- a/source/tests/pt/test_training.py
+++ b/source/tests/pt/test_training.py
@@ -34,6 +34,7 @@ def test_dp_train(self):
         # test training from scratch
         trainer = get_trainer(deepcopy(self.config))
         trainer.run()
+        state_dict_trained = trainer.wrapper.model.state_dict()
 
         # test fine-tuning using same input
         finetune_model = self.config["training"].get("save_ckpt", "model.ckpt") + ".pt"
@@ -46,7 +47,6 @@ def test_dp_train(self):
             finetune_model=finetune_model,
             finetune_links=finetune_links,
         )
-        trainer_finetune.run()
 
         # test fine-tuning using empty input
         self.config_empty = deepcopy(self.config)
@@ -64,7 +64,41 @@ def test_dp_train(self):
             finetune_model=finetune_model,
             finetune_links=finetune_links,
         )
+
+        # test fine-tuning using random fitting
+        self.config["model"], finetune_links = get_finetune_rules(
+            finetune_model, self.config["model"], model_branch="RANDOM"
+        )
+        trainer_finetune_random = get_trainer(
+            deepcopy(self.config_empty),
+            finetune_model=finetune_model,
+            finetune_links=finetune_links,
+        )
+
+        # check parameters
+        state_dict_finetuned = trainer_finetune.wrapper.model.state_dict()
+        state_dict_finetuned_empty = trainer_finetune_empty.wrapper.model.state_dict()
+        state_dict_finetuned_random = trainer_finetune_random.wrapper.model.state_dict()
+        for state_key in state_dict_finetuned:
+            if "out_bias" not in state_key and "out_std" not in state_key:
+                torch.testing.assert_close(
+                    state_dict_trained[state_key],
+                    state_dict_finetuned[state_key],
+                )
+                torch.testing.assert_close(
+                    state_dict_trained[state_key],
+                    state_dict_finetuned_empty[state_key],
+                )
+                if "fitting_net" not in state_key:
+                    torch.testing.assert_close(
+                        state_dict_trained[state_key],
+                        state_dict_finetuned_random[state_key],
+                    )
+
+        # check running
+        trainer_finetune.run()
         trainer_finetune_empty.run()
+        trainer_finetune_random.run()
 
     def test_trainable(self):
         fix_params = deepcopy(self.config)