Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pt: support multitask finetune #3480

Merged
merged 17 commits into from
Mar 22, 2024
Merged
15 changes: 8 additions & 7 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@
dist.init_process_group(backend="nccl")

ckpt = init_model if init_model is not None else restart_model
config["model"] = change_finetune_model_params(
ckpt,
finetune_model,
config["model"],
multi_task=multi_task,
model_branch=model_branch,
)
finetune_links = None
if finetune_model is not None:
config["model"], finetune_links = change_finetune_model_params(
finetune_model,

Check warning on line 96 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L93-L96

Added lines #L93 - L96 were not covered by tests
config["model"],
model_branch=model_branch,
)
config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)

def prepare_trainer_input_single(
Expand Down Expand Up @@ -194,6 +194,7 @@
finetune_model=finetune_model,
force_load=force_load,
shared_links=shared_links,
finetune_links=finetune_links,
init_frz_model=init_frz_model,
)
return trainer
Expand Down
4 changes: 0 additions & 4 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,6 @@ def compute_or_load_stat(
self.models[0].compute_or_load_stat(sampled_func, stat_file_path)
self.models[1].compute_or_load_stat(sampled_func, stat_file_path)

def change_energy_bias(self):
    """Placeholder for changing the energy bias; currently does nothing."""
    # need to implement
    pass

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 0 additions & 4 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,6 @@ def compute_or_load_stat(
torch.tensor(bias_atom_e, device=env.DEVICE).view([self.ntypes, 1])
)

def change_energy_bias(self) -> None:
    """Placeholder for changing the energy bias; currently does nothing."""
    # need to implement
    pass

def forward_atomic(
self,
extended_coord: torch.Tensor,
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/dipole_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,8 @@
else:
model_predict = model_ret
return model_predict

def change_out_bias(
    self, merged, origin_type_map, full_type_map, bias_shift="delta"
) -> None:
    """Change the output bias of the model; not supported by this model.

    Parameters
    ----------
    merged
        Accepted for interface compatibility; unused here.
    origin_type_map
        Accepted for interface compatibility; unused here.
    full_type_map
        Accepted for interface compatibility; unused here.
    bias_shift : str
        Accepted for interface compatibility; unused here.

    Raises
    ------
    NotImplementedError
        Always, since this model does not implement the operation.
    """
    # Name the concrete class so the failure is actionable from the traceback.
    raise NotImplementedError(
        f"{type(self).__name__} does not implement change_out_bias yet."
    )

Check warning on line 99 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L99

Added line #L99 was not covered by tests
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/dos_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,8 @@
else:
model_predict = model_ret
return model_predict

def change_out_bias(
    self, merged, origin_type_map, full_type_map, bias_shift="delta"
) -> None:
    """Change the output bias of the model; not supported by this model.

    Parameters
    ----------
    merged
        Accepted for interface compatibility; unused here.
    origin_type_map
        Accepted for interface compatibility; unused here.
    full_type_map
        Accepted for interface compatibility; unused here.
    bias_shift : str
        Accepted for interface compatibility; unused here.

    Raises
    ------
    NotImplementedError
        Always, since this model does not implement the operation.
    """
    # Name the concrete class so the failure is actionable from the traceback.
    raise NotImplementedError(
        f"{type(self).__name__} does not implement change_out_bias yet."
    )

Check warning on line 85 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L85

Added line #L85 was not covered by tests
71 changes: 71 additions & 0 deletions deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,31 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
import tempfile

Check warning on line 4 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L2-L4

Added lines #L2 - L4 were not covered by tests
from typing import (
Dict,
Optional,
)

import numpy as np

Check warning on line 10 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L10

Added line #L10 was not covered by tests
import torch

from deepmd.infer.deep_eval import (

Check warning on line 13 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L13

Added line #L13 was not covered by tests
DeepEval,
)
from deepmd.pt.utils.stat import (

Check warning on line 16 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L16

Added line #L16 was not covered by tests
compute_output_stats,
)
from deepmd.pt.utils.utils import (

Check warning on line 19 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L19

Added line #L19 was not covered by tests
to_numpy_array,
)

from .dp_model import (
DPModel,
)

log = logging.getLogger(__name__)

Check warning on line 27 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L27

Added line #L27 was not covered by tests


class EnergyModel(DPModel):
model_type = "ener"
Expand Down Expand Up @@ -97,3 +113,58 @@
else:
model_predict = model_ret
return model_predict

def change_out_bias(
    self, merged, origin_type_map, full_type_map, bias_shift="delta"
) -> None:
    """Change the energy bias according to the input data and the pretrained model.

    Parameters
    ----------
    merged : Union[Callable[[], List[dict]], List[dict]]
        - List[dict]: A list of data samples from various data systems.
            Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
            originating from the `i`-th data system.
        - Callable[[], List[dict]]: A lazy function that returns data samples in the above format
            only when needed. Since the sampling process can be slow and memory-intensive,
            the lazy function helps by only sampling once.
    origin_type_map : List[str]
        The original type_map in dataset, they are targets to change the energy bias.
    full_type_map : List[str]
        The full type_map in pre-trained model
    bias_shift : str
        The mode for changing energy bias : ['delta', 'statistic']
        'delta' : perform predictions on energies of target dataset,
            and do least squares on the errors to obtain the target shift as bias.
        'statistic' : directly use the statistic energy bias in the target dataset.

    Raises
    ------
    RuntimeError
        If `bias_shift` is neither 'delta' nor 'statistic'.
    """
    # Positions of the dataset types inside the pre-trained type map.
    # They are only used to pick the relevant entries for the log message below.
    sorter = np.argsort(full_type_map)
    idx_type_map = sorter[
        np.searchsorted(full_type_map, origin_type_map, sorter=sorter)
    ]
    original_bias = self.get_fitting_net()["bias_atom_e"]
    if bias_shift == "delta":
        # DeepEval is built from a file path, so the scripted model is
        # round-tripped through a temporary file.
        tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
        # Only the name is needed: close the handle right away so it is not
        # leaked and the path can be rewritten/removed on platforms that lock
        # open files.
        tmp_model.close()
        try:
            torch.jit.save(torch.jit.script(self), tmp_model.name)
            dp = DeepEval(tmp_model.name)
        finally:
            # Remove the file even when scripting, saving or loading fails.
            os.unlink(tmp_model.name)
        delta_bias_e = compute_output_stats(
            merged,
            self.atomic_model.get_ntypes(),
            model=dp,
        )
        # In 'delta' mode the computed statistics are a shift on top of the
        # bias already stored in the fitting net.
        bias_atom_e = delta_bias_e + original_bias
    elif bias_shift == "statistic":
        bias_atom_e = compute_output_stats(
            merged,
            self.atomic_model.get_ntypes(),
        )
    else:
        # f-string keeps the RuntimeError even when bias_shift is not a str.
        raise RuntimeError(f"Unknown bias_shift mode: {bias_shift}")
    log.info(
        f"Change energy bias of {origin_type_map!s} "
        f"from {to_numpy_array(original_bias[idx_type_map]).reshape(-1)!s} "
        f"to {to_numpy_array(bias_atom_e[idx_type_map]).reshape(-1)!s}."
    )
    self.get_fitting_net()["bias_atom_e"] = bias_atom_e

Check warning on line 170 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L170

Added line #L170 was not covered by tests
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/polar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,8 @@
else:
model_predict = model_ret
return model_predict

def change_out_bias(
    self, merged, origin_type_map, full_type_map, bias_shift="delta"
) -> None:
    """Change the output bias of the model; not supported by this model.

    Parameters
    ----------
    merged
        Accepted for interface compatibility; unused here.
    origin_type_map
        Accepted for interface compatibility; unused here.
    full_type_map
        Accepted for interface compatibility; unused here.
    bias_shift : str
        Accepted for interface compatibility; unused here.

    Raises
    ------
    NotImplementedError
        Always, since this model does not implement the operation.
    """
    # Name the concrete class so the failure is actionable from the traceback.
    raise NotImplementedError(
        f"{type(self).__name__} does not implement change_out_bias yet."
    )

Check warning on line 83 in deepmd/pt/model/model/polar_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/polar_model.py#L83

Added line #L83 was not covered by tests
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,3 +558,8 @@
].squeeze(-2)
# not support virial by far
return model_predict

def change_out_bias(
    self, merged, origin_type_map, full_type_map, bias_shift="delta"
) -> None:
    """Change the output bias of the model; not supported by this model.

    Parameters
    ----------
    merged
        Accepted for interface compatibility; unused here.
    origin_type_map
        Accepted for interface compatibility; unused here.
    full_type_map
        Accepted for interface compatibility; unused here.
    bias_shift : str
        Accepted for interface compatibility; unused here.

    Raises
    ------
    NotImplementedError
        Always, since this model does not implement the operation.
    """
    # Name the concrete class so the failure is actionable from the traceback.
    raise NotImplementedError(
        f"{type(self).__name__} does not implement change_out_bias yet."
    )

Check warning on line 565 in deepmd/pt/model/model/spin_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/spin_model.py#L565

Added line #L565 was not covered by tests
78 changes: 0 additions & 78 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
import os
import tempfile
from abc import (
abstractmethod,
)
Expand All @@ -15,9 +13,6 @@
import numpy as np
import torch

from deepmd.infer.deep_eval import (
DeepEval,
)
from deepmd.pt.model.network.mlp import (
FittingNet,
NetworkCollection,
Expand All @@ -33,7 +28,6 @@
)
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
DEVICE,
PRECISION_DICT,
)
from deepmd.pt.utils.exclude_mask import (
Expand All @@ -43,12 +37,6 @@
to_numpy_array,
to_torch_tensor,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.finetune import (
change_energy_bias_lower,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE
Expand Down Expand Up @@ -88,72 +76,6 @@ def share_params(self, base_class, shared_level, resume=False):
else:
raise NotImplementedError

def change_energy_bias(
    self,
    config,
    model,
    old_type_map: List[str],
    new_type_map: List[str],
    bias_shift="delta",
    ntest=10,
):
    """Change the energy bias according to the input data and the pretrained model.

    Parameters
    ----------
    config : Dict
        The configuration.
    model : EnergyModel
        Energy model loaded pre-trained model.
    old_type_map : List[str]
        The full type_map in pretrained model
    new_type_map : List[str]
        The original type_map in dataset, they are targets to change the energy bias.
    bias_shift : str
        The mode for changing energy bias : ['delta', 'statistic']
        'delta' : perform predictions on energies of target dataset,
            and do least squares on the errors to obtain the target shift as bias.
        'statistic' : directly use the statistic energy bias in the target dataset.
    ntest : int
        The number of test samples in a system to change the energy bias.
    """
    log.info(
        f"Changing energy bias in pretrained model for types {new_type_map!s}... "
        "(this step may take long time)"
    )
    # data
    systems = config["training"]["training_data"]["systems"]
    finetune_data = DeepmdDataSystem(
        systems=systems,
        batch_size=config["training"]["training_data"].get("batch_size", "auto"),
        test_size=1,
    )
    finetune_data.add("energy", ndof=1, atomic=False, must=True, high_prec=True)
    model = torch.jit.script(model)
    # Frame/atomic parameters are only required from the data when the model
    # reports a non-zero dimension for them.
    if model.get_dim_fparam() > 0:
        finetune_data.add("fparam", model.get_dim_fparam(), atomic=False, must=True)
    if model.get_dim_aparam() > 0:
        finetune_data.add("aparam", model.get_dim_aparam(), atomic=True, must=True)
    # DeepEval is built from a file path, so the scripted model is
    # round-tripped through a temporary file.
    tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
    # Only the name is needed: close the handle right away so it is not leaked
    # and the path can be rewritten/removed on platforms that lock open files.
    tmp_model.close()
    try:
        torch.jit.save(model, tmp_model.name)
        dp = DeepEval(tmp_model.name)
    finally:
        # Remove the file even when saving or loading fails.
        os.unlink(tmp_model.name)
    bias = change_energy_bias_lower(
        finetune_data,
        dp,
        new_type_map,
        old_type_map,
        self.bias_atom_e.detach().cpu().numpy().reshape(-1),
        bias_shift=bias_shift,
        ntest=ntest,
    )
    # Store the new bias with the dtype and shape of the existing buffer.
    self.bias_atom_e = (
        torch.from_numpy(bias)
        .type_as(self.bias_atom_e)
        .reshape(self.bias_atom_e.shape)
        .to(DEVICE)
    )


class GeneralFitting(Fitting):
"""Construct a general fitting net.
Expand Down
Loading
Loading