feat(pt): support fine-tuning from random fitting (deepmodeling#3914)
Support fine-tuning from random fitting in single-from-single
fine-tuning.

## Summary by CodeRabbit

- **New Features**
  - Introduced the ability to randomly initialize fitting nets for fine-tuning by setting the model branch to "RANDOM".
  - Added support for "RANDOM" model branches in multitask pre-trained model scenarios.

- **Documentation**
  - Updated fine-tuning documentation to include details on handling fitting net weights and using the `--model-branch RANDOM` parameter.

- **Tests**
  - Added tests to verify the random initialization of fitting nets and parameter closeness checks during fine-tuning.
  - Updated test cases to include "RANDOM" model branches in the output.
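For quick reference, the command this feature enables for single-task fine-tuning is sketched below; it is taken from the documentation change in this commit, with `input.json` and `pretrained.pt` as placeholder file names.

```bash
# Fine-tune from pretrained.pt but start the fitting net from random weights
# instead of inheriting the pre-trained fitting parameters.
$ dp --pt train input.json --finetune pretrained.pt --model-branch RANDOM
```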
iProzd authored and Mathieu Taillefumier committed Sep 18, 2024
1 parent 2698dbf commit 7c35acb
Showing 5 changed files with 72 additions and 14 deletions.
10 changes: 9 additions & 1 deletion deepmd/pt/entrypoints/main.py
@@ -229,6 +229,10 @@ def train(FLAGS):
     shared_links = None
     if multi_task:
         config["model"], shared_links = preprocess_shared_params(config["model"])
+        # handle the special key
+        assert (
+            "RANDOM" not in config["model"]["model_dict"]
+        ), "Model name can not be 'RANDOM' in multi-task mode!"
 
     # update fine-tuning config
     finetune_links = None
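The assertion added above reserves the branch name for the CLI; a minimal runnable sketch with hypothetical config contents shows what it now rejects.

```python
# Minimal sketch (hypothetical config contents): in multi-task mode the key
# "RANDOM" is reserved for the CLI and may not be used as a model_dict branch name.
config = {"model": {"model_dict": {"model_1": {}, "model_2": {}}}}

assert (
    "RANDOM" not in config["model"]["model_dict"]
), "Model name can not be 'RANDOM' in multi-task mode!"
print("Branch names are valid:", list(config["model"]["model_dict"]))
```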
@@ -337,7 +341,11 @@ def show(FLAGS):
                 " The provided model does not meet this criterion."
             )
         model_branches = list(model_params["model_dict"].keys())
-        log.info(f"Available model branches are {model_branches}")
+        model_branches += ["RANDOM"]
+        log.info(
+            f"Available model branches are {model_branches}, "
+            f"where 'RANDOM' means using a randomly initialized fitting net."
+        )
     if "type-map" in FLAGS.ATTRIBUTES:
         if model_is_multi_task:
             model_branches = list(model_params["model_dict"].keys())
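With the `show()` change above, listing the branches of a multi-task checkpoint now also advertises the pseudo-branch. A sketch of the expected output follows; the checkpoint name is illustrative, and the branch names `model_1`/`model_2` are taken from the updated test further below.

```bash
$ dp --pt show multitask_pretrained.pt model-branch
# Expected log lines (sketch):
#   This is a multitask model
#   Available model branches are ['model_1', 'model_2', 'RANDOM'], where 'RANDOM' means using a randomly initialized fitting net.
```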
19 changes: 14 additions & 5 deletions deepmd/pt/utils/finetune.py
@@ -30,9 +30,12 @@ def get_finetune_rule_single(
 
     if not from_multitask:
         single_config_chosen = deepcopy(_model_param_pretrained)
+        if model_branch_from == "RANDOM":
+            # not ["", "RANDOM"], because single-from-single finetune uses pretrained fitting in default
+            new_fitting = True
     else:
         model_dict_params = _model_param_pretrained["model_dict"]
-        if model_branch_from == "":
+        if model_branch_from in ["", "RANDOM"]:
             model_branch_chosen = next(iter(model_dict_params.keys()))
             new_fitting = True
             log.warning(
@@ -164,21 +167,27 @@ def get_finetune_rules(
         pretrained_keys = last_model_params["model_dict"].keys()
         for model_key in target_keys:
             resuming = False
-            if "finetune_head" in model_config["model_dict"][model_key]:
+            if (
+                "finetune_head" in model_config["model_dict"][model_key]
+                and model_config["model_dict"][model_key]["finetune_head"] != "RANDOM"
+            ):
                 pretrained_key = model_config["model_dict"][model_key]["finetune_head"]
                 assert pretrained_key in pretrained_keys, (
                     f"'{pretrained_key}' head chosen to finetune not exist in the pretrained model!"
                     f"Available heads are: {list(pretrained_keys)}"
                 )
                 model_branch_from = pretrained_key
-            elif model_key in pretrained_keys:
+            elif (
+                "finetune_head" not in model_config["model_dict"][model_key]
+                and model_key in pretrained_keys
+            ):
                 # not do anything if not defined "finetune_head" in heads that exist in the pretrained model
                 # this will just do resuming
                 model_branch_from = model_key
                 resuming = True
             else:
-                # if not defined "finetune_head" in new heads, the fitting net will bre randomly initialized
-                model_branch_from = ""
+                # if not defined "finetune_head" in new heads or "finetune_head" is "RANDOM", the fitting net will be randomly initialized
+                model_branch_from = "RANDOM"
             model_config["model_dict"][model_key], finetune_rule = (
                 get_finetune_rule_single(
                     model_config["model_dict"][model_key],
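In summary, `get_finetune_rules` now resolves each head as follows: an explicit `finetune_head` (other than "RANDOM") selects that pre-trained branch, an omitted `finetune_head` for a head that already exists in the pre-trained model simply resumes it, and anything else falls back to a randomly initialized fitting net. The sketch below is a simplified stand-in for that logic, not the actual implementation.

```python
# Simplified sketch of the per-head resolution in get_finetune_rules (not the
# actual implementation). `head_config` is one entry of model_dict and
# `pretrained_keys` are the heads available in the pre-trained checkpoint.
# Returns (model_branch_from, resuming).
def resolve_branch(model_key, head_config, pretrained_keys):
    finetune_head = head_config.get("finetune_head")
    if finetune_head is not None and finetune_head != "RANDOM":
        # an existing pre-trained head was explicitly chosen
        return finetune_head, False
    if finetune_head is None and model_key in pretrained_keys:
        # the same head exists in the pre-trained model: just resume it
        return model_key, True
    # new head, or finetune_head explicitly set to "RANDOM":
    # the fitting net will be randomly initialized
    return "RANDOM", False


print(resolve_branch("water", {"finetune_head": "RANDOM"}, {"model_1", "model_2"}))  # ('RANDOM', False)
```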
15 changes: 9 additions & 6 deletions doc/train/finetuning.md
@@ -68,9 +68,12 @@ The command for this operation is:
 $ dp --pt train input.json --finetune pretrained.pt
 ```
 
-:::{note}
-We do not support fine-tuning from a randomly initialized fitting net in this case, which is the same as implementations in TensorFlow.
-:::
+In this case, it is important to note that the fitting net weights, except the energy bias, will be automatically set to those in the pre-trained model. This default setting is consistent with the implementations in TensorFlow.
+If you wish to conduct fine-tuning using a randomly initialized fitting net in this scenario, you can manually adjust the `--model-branch` parameter to "RANDOM":
+
+```bash
+$ dp --pt train input.json --finetune pretrained.pt --model-branch RANDOM
+```
 
 The model section in input.json **must be the same as that in the pretrained model**.
 If you do not know the model params in the pretrained model, you can add `--use-pretrain-script` in the fine-tuning command:
@@ -91,9 +94,9 @@ The model section will be overwritten (except the `type_map` subsection) by that
 
 #### Fine-tuning from a multi-task pre-trained model
 
-Additionally, within the PyTorch implementation and leveraging the flexibility offered by the framework and the multi-task training capabilities provided by DPA2,
+Additionally, within the PyTorch implementation and leveraging the flexibility offered by the framework and the multi-task training process proposed in the DPA2 [paper](https://arxiv.org/abs/2312.15492),
 we also support more general multitask pre-trained models, which includes multiple datasets for pre-training. These pre-training datasets share a common descriptor while maintaining their individual fitting nets,
-as detailed in the DPA2 [paper](https://arxiv.org/abs/2312.15492).
+as detailed in the paper above.
 
 For fine-tuning using this multitask pre-trained model (`multitask_pretrained.pt`),
 one can select a specific branch (e.g., `CHOOSEN_BRANCH`) included in `multitask_pretrained.pt` for fine-tuning with the following command:
@@ -112,7 +115,7 @@ $ dp --pt show multitask_pretrained.pt model-branch
 :::
 
 This command will start fine-tuning based on the pre-trained model's descriptor and the selected branch's fitting net.
-If --model-branch is not set, a randomly initialized fitting net will be used.
+If --model-branch is not set or set to "RANDOM", a randomly initialized fitting net will be used.
 
 ### Multi-task fine-tuning
 
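For reference, the branch-selection command that the text above refers to has the following shape (a sketch; `CHOOSEN_BRANCH` is the placeholder branch name used in the documentation):

```bash
# Fine-tune from the CHOOSEN_BRANCH fitting net of the multi-task checkpoint;
# pass --model-branch RANDOM instead to start from a randomly initialized fitting net.
$ dp --pt train input.json --finetune multitask_pretrained.pt --model-branch CHOOSEN_BRANCH
```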
6 changes: 5 additions & 1 deletion source/tests/pt/test_dp_show.py
@@ -149,7 +149,11 @@ def test_checkpoint(self):
         run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
         results = f.getvalue().split("\n")[:-1]
         assert "This is a multitask model" in results[-8]
-        assert "Available model branches are ['model_1', 'model_2']" in results[-7]
+        assert (
+            "Available model branches are ['model_1', 'model_2', 'RANDOM'], "
+            "where 'RANDOM' means using a randomly initialized fitting net."
+            in results[-7]
+        )
         assert "The type_map of branch model_1 is ['O', 'H', 'B']" in results[-6]
         assert "The type_map of branch model_2 is ['O', 'H', 'B']" in results[-5]
         assert (
36 changes: 35 additions & 1 deletion source/tests/pt/test_training.py
@@ -34,6 +34,7 @@ def test_dp_train(self):
         # test training from scratch
         trainer = get_trainer(deepcopy(self.config))
         trainer.run()
+        state_dict_trained = trainer.wrapper.model.state_dict()
 
         # test fine-tuning using same input
         finetune_model = self.config["training"].get("save_ckpt", "model.ckpt") + ".pt"
@@ -46,7 +47,6 @@ def test_dp_train(self):
             finetune_model=finetune_model,
             finetune_links=finetune_links,
         )
-        trainer_finetune.run()
 
         # test fine-tuning using empty input
         self.config_empty = deepcopy(self.config)
@@ -64,7 +64,41 @@ def test_dp_train(self):
             finetune_model=finetune_model,
             finetune_links=finetune_links,
         )
+
+        # test fine-tuning using random fitting
+        self.config["model"], finetune_links = get_finetune_rules(
+            finetune_model, self.config["model"], model_branch="RANDOM"
+        )
+        trainer_finetune_random = get_trainer(
+            deepcopy(self.config_empty),
+            finetune_model=finetune_model,
+            finetune_links=finetune_links,
+        )
+
+        # check parameters
+        state_dict_finetuned = trainer_finetune.wrapper.model.state_dict()
+        state_dict_finetuned_empty = trainer_finetune_empty.wrapper.model.state_dict()
+        state_dict_finetuned_random = trainer_finetune_random.wrapper.model.state_dict()
+        for state_key in state_dict_finetuned:
+            if "out_bias" not in state_key and "out_std" not in state_key:
+                torch.testing.assert_close(
+                    state_dict_trained[state_key],
+                    state_dict_finetuned[state_key],
+                )
+                torch.testing.assert_close(
+                    state_dict_trained[state_key],
+                    state_dict_finetuned_empty[state_key],
+                )
+                if "fitting_net" not in state_key:
+                    torch.testing.assert_close(
+                        state_dict_trained[state_key],
+                        state_dict_finetuned_random[state_key],
+                    )
+
+        # check running
+        trainer_finetune.run()
         trainer_finetune_empty.run()
+        trainer_finetune_random.run()
 
     def test_trainable(self):
         fix_params = deepcopy(self.config)
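Beyond the test itself, here is a small standalone sketch (not part of this commit) of how one might compare two state dicts obtained as above, e.g. via `trainer.wrapper.model.state_dict()`, to confirm that only fitting-net related parameters were re-initialized when fine-tuning with `--model-branch RANDOM`.

```python
# Standalone sketch (not part of the commit): list the keys whose tensors differ
# between a reference state dict and a fine-tuned one. When fine-tuning with
# --model-branch RANDOM, the differing keys are expected to be the fitting-net
# weights plus the out_bias/out_std statistics, while descriptor keys match.
import torch


def changed_keys(sd_ref, sd_new, rtol=1e-5, atol=1e-8):
    diffs = []
    for key, ref in sd_ref.items():
        new = sd_new[key]
        if ref.dtype.is_floating_point:
            same = torch.allclose(ref, new, rtol=rtol, atol=atol)
        else:
            same = torch.equal(ref, new)
        if not same:
            diffs.append(key)
    return diffs


# Example usage (names as in the test above):
# print(changed_keys(state_dict_trained, state_dict_finetuned_random))
```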
