Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pt): finetuning property/dipole/polar/dos fitting with multi-dimensional data causes error #4145

Merged
merged 16 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,11 @@ def collect_single_finetune_params(
if i != "_extra_state" and f".{_model_key}." in i
]
for item_key in target_keys:
if _new_fitting and ".fitting_net." in item_key:
if _new_fitting and (
(".fitting_net." in item_key)
or (".out_bias" in item_key)
or (".out_std" in item_key)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
):
# print(f'Keep {item_key} in old model!')
_new_state_dict[item_key] = (
_random_state_dict[item_key].clone().detach()
Expand Down
68 changes: 68 additions & 0 deletions source/tests/pt/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,5 +448,73 @@ def tearDown(self) -> None:
DPTrainTest.tearDown(self)


class TestPropFintuFromEnerModel(unittest.TestCase):
def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json) as f:
self.config = json.load(f)
data_file = [str(Path(__file__).parent / "water/data/data_0")]
self.config["training"]["training_data"]["systems"] = data_file
self.config["training"]["validation_data"]["systems"] = data_file
self.config["model"] = deepcopy(model_dpa1)
self.config["model"]["type_map"] = ["H", "C", "N", "O"]
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1

property_input = str(Path(__file__).parent / "property/input.json")
with open(property_input) as f:
self.config_property = json.load(f)
prop_data_file = [str(Path(__file__).parent / "property/single")]
self.config_property["training"]["training_data"]["systems"] = prop_data_file
self.config_property["training"]["validation_data"]["systems"] = prop_data_file
self.config_property["model"]["descriptor"] = deepcopy(model_dpa1["descriptor"])
self.config_property["training"]["numb_steps"] = 1
self.config_property["training"]["save_freq"] = 1

def test_dp_train(self):
# test training from scratch
trainer = get_trainer(deepcopy(self.config))
trainer.run()
state_dict_trained = trainer.wrapper.model.state_dict()

# test fine-tuning using diffferent fitting_net, here using property fitting
finetune_model = self.config["training"].get("save_ckpt", "model.ckpt") + ".pt"
self.config_property["model"], finetune_links = get_finetune_rules(
finetune_model,
self.config_property["model"],
model_branch="RANDOM",
)
trainer_finetune = get_trainer(
deepcopy(self.config_property),
finetune_model=finetune_model,
finetune_links=finetune_links,
)

# check parameters
state_dict_finetuned = trainer_finetune.wrapper.model.state_dict()
for state_key in state_dict_finetuned:
if (
"out_bias" not in state_key
and "out_std" not in state_key
and "fitting" not in state_key
):
torch.testing.assert_close(
state_dict_trained[state_key],
state_dict_finetuned[state_key],
)

# check running
trainer_finetune.run()

def tearDown(self):
for f in os.listdir("."):
if f.startswith("model") and f.endswith(".pt"):
os.remove(f)
if f in ["lcurve.out"]:
os.remove(f)
if f in ["stat_files"]:
shutil.rmtree(f)


if __name__ == "__main__":
unittest.main()