Skip to content

Commit

Permalink
pt: support multitask finetune (#3480)
Browse files Browse the repository at this point in the history
This PR:
1. merge `change_energy_bias` into `compute_output_stats` and reformat
it into `change_out_bias` of `model` level.
2. support single-task/multi-task finetuning from single-task/multi-task
pretrained model.

Need fix in future PR:
1. Finetuned model has covered `type_map`. (If fixed, `change_out_bias`
func will not need the input params `origin_type_map` and
`full_type_map`.) See also #3455.
2. `change_out_bias` support for other models.(e.g. Spin, ZBL, Polar,
Dipole and Dos.)
  • Loading branch information
iProzd authored Mar 22, 2024
1 parent fb61efb commit e47478f
Show file tree
Hide file tree
Showing 23 changed files with 918 additions and 358 deletions.
21 changes: 21 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,27 @@ def mixed_types(self) -> bool:
"""
return self.descriptor.mixed_types()

def set_out_bias(self, out_bias: np.ndarray, add=False) -> None:
"""
Modify the output bias for the atomic model.
Parameters
----------
out_bias : np.ndarray
The new bias to be applied.
add : bool, optional
Whether to add the new bias to the existing one.
If False, the output bias will be directly replaced by the new bias.
If True, the new bias will be added to the existing one.
"""
self.fitting["bias_atom_e"] = (
out_bias + self.fitting["bias_atom_e"] if add else out_bias
)

def get_out_bias(self) -> np.ndarray:
"""Return the output bias of the atomic model."""
return self.fitting["bias_atom_e"]

def forward_atomic(
self,
extended_coord: np.ndarray,
Expand Down
21 changes: 21 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,27 @@ def get_sel_type(self) -> List[int]:
# join all the selected types
return list(set().union(*[model.get_sel_type() for model in self.models]))

def set_out_bias(self, out_bias: np.ndarray, add=False) -> None:
"""
Modify the output bias for all the models in the linear atomic model.
Parameters
----------
out_bias : torch.Tensor
The new bias to be applied.
add : bool, optional
Whether to add the new bias to the existing one.
If False, the output bias will be directly replaced by the new bias.
If True, the new bias will be added to the existing one.
"""
for model in self.models:
model.set_out_bias(out_bias, add=add)

def get_out_bias(self) -> np.ndarray:
"""Return the weighted output bias of the linear atomic model."""
# TODO add get_out_bias for linear atomic model
raise NotImplementedError

def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).
Expand Down
19 changes: 19 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,25 @@ def get_sel_type(self) -> List[int]:
If returning an empty list, all atom types are selected.
"""

@abstractmethod
def set_out_bias(self, out_bias: t_tensor, add=False) -> None:
"""
Modify the output bias for the atomic model.
Parameters
----------
out_bias : t_tensor
The new bias to be applied.
add : bool, optional
Whether to add the new bias to the existing one.
If False, the output bias will be directly replaced by the new bias.
If True, the new bias will be added to the existing one.
"""

@abstractmethod
def get_out_bias(self) -> t_tensor:
"""Return the output bias of the atomic model."""

@abstractmethod
def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).
Expand Down
19 changes: 19 additions & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,25 @@ def mixed_types(self) -> bool:
# to match DPA1 and DPA2.
return True

def set_out_bias(self, out_bias: np.ndarray, add=False) -> None:
"""
Modify the output bias for the atomic model.
Parameters
----------
out_bias : torch.Tensor
The new bias to be applied.
add : bool, optional
Whether to add the new bias to the existing one.
If False, the output bias will be directly replaced by the new bias.
If True, the new bias will be added to the existing one.
"""
self.bias_atom_e = out_bias + self.bias_atom_e if add else out_bias

def get_out_bias(self) -> np.ndarray:
"""Return the output bias of the atomic model."""
return self.bias_atom_e

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
15 changes: 8 additions & 7 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ def get_trainer(
dist.init_process_group(backend="nccl")

ckpt = init_model if init_model is not None else restart_model
config["model"] = change_finetune_model_params(
ckpt,
finetune_model,
config["model"],
multi_task=multi_task,
model_branch=model_branch,
)
finetune_links = None
if finetune_model is not None:
config["model"], finetune_links = change_finetune_model_params(
finetune_model,
config["model"],
model_branch=model_branch,
)
config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)

def prepare_trainer_input_single(
Expand Down Expand Up @@ -194,6 +194,7 @@ def prepare_trainer_input_single(
finetune_model=finetune_model,
force_load=force_load,
shared_links=shared_links,
finetune_links=finetune_links,
init_frz_model=init_frz_model,
)
return trainer
Expand Down
107 changes: 107 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


import logging
from typing import (
Callable,
Dict,
List,
Optional,
Tuple,
)

import numpy as np
import torch

from deepmd.dpmodel.atomic_model import (
Expand All @@ -21,10 +24,21 @@
AtomExcludeMask,
PairExcludeMask,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.path import (
DPPath,
)

log = logging.getLogger(__name__)

BaseAtomicModel_ = make_base_atomic_model(torch.Tensor)


Expand Down Expand Up @@ -176,6 +190,40 @@ def serialize(self) -> dict:
"pair_exclude_types": self.pair_exclude_types,
}

def get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
"""Get a forward wrapper of the atomic model for output bias calculation."""
model_output_type = list(self.atomic_output_def().keys())
if "mask" in model_output_type:
model_output_type.pop(model_output_type.index("mask"))
out_name = model_output_type[0]

def model_forward(coord, atype, box, fparam=None, aparam=None):
with torch.no_grad(): # it's essential for pure torch forward function to use auto_batchsize
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
self.get_sel(),
mixed_types=self.mixed_types(),
box=box,
)
atomic_ret = self.forward_common_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)
return atomic_ret[out_name].detach()

return model_forward

def compute_or_load_stat(
self,
sampled_func,
Expand All @@ -197,3 +245,62 @@ def compute_or_load_stat(
The path to the statistics files.
"""
raise NotImplementedError

def change_out_bias(
self,
merged,
origin_type_map,
full_type_map,
bias_adjust_mode="change-by-statistic",
) -> None:
"""Change the output bias according to the input data and the pretrained model.
Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
origin_type_map : List[str]
The original type_map in dataset, they are targets to change the output bias.
full_type_map : List[str]
The full type_map in pre-trained model
bias_adjust_mode : str
The mode for changing output bias : ['change-by-statistic', 'set-by-statistic']
'change-by-statistic' : perform predictions on labels of target dataset,
and do least square on the errors to obtain the target shift as bias.
'set-by-statistic' : directly use the statistic output bias in the target dataset.
"""
sorter = np.argsort(full_type_map)
missing_types = [t for t in origin_type_map if t not in full_type_map]
assert (
not missing_types
), f"Some types are not in the pre-trained model: {list(missing_types)} !"
idx_type_map = sorter[
np.searchsorted(full_type_map, origin_type_map, sorter=sorter)
]
original_bias = self.get_out_bias()
if bias_adjust_mode == "change-by-statistic":
delta_bias = compute_output_stats(
merged,
self.get_ntypes(),
model_forward=self.get_forward_wrapper_func(),
)
self.set_out_bias(delta_bias, add=True)
elif bias_adjust_mode == "set-by-statistic":
bias_atom = compute_output_stats(
merged,
self.get_ntypes(),
)
self.set_out_bias(bias_atom)
else:
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
bias_atom = self.get_out_bias()
log.info(
f"Change output bias of {origin_type_map!s} "
f"from {to_numpy_array(original_bias[idx_type_map]).reshape(-1)!s} "
f"to {to_numpy_array(bias_atom[idx_type_map]).reshape(-1)!s}."
)
21 changes: 21 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,27 @@ def wrapped_sampler():
if self.fitting_net is not None:
self.fitting_net.compute_output_stats(wrapped_sampler, stat_file_path)

def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None:
"""
Modify the output bias for the atomic model.
Parameters
----------
out_bias : torch.Tensor
The new bias to be applied.
add : bool, optional
Whether to add the new bias to the existing one.
If False, the output bias will be directly replaced by the new bias.
If True, the new bias will be added to the existing one.
"""
self.fitting_net["bias_atom_e"] = (
out_bias + self.fitting_net["bias_atom_e"] if add else out_bias
)

def get_out_bias(self) -> torch.Tensor:
"""Return the output bias of the atomic model."""
return self.fitting_net["bias_atom_e"]

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return self.fitting_net.get_dim_fparam()
Expand Down
25 changes: 21 additions & 4 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,27 @@ def _compute_weight(
for _ in range(nmodels)
]

def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None:
"""
Modify the output bias for all the models in the linear atomic model.
Parameters
----------
out_bias : torch.Tensor
The new bias to be applied.
add : bool, optional
Whether to add the new bias to the existing one.
If False, the output bias will be directly replaced by the new bias.
If True, the new bias will be added to the existing one.
"""
for model in self.models:
model.set_out_bias(out_bias, add=add)

def get_out_bias(self) -> torch.Tensor:
"""Return the weighted output bias of the linear atomic model."""
# TODO add get_out_bias for linear atomic model
raise NotImplementedError

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
# tricky...
Expand Down Expand Up @@ -390,10 +411,6 @@ def compute_or_load_stat(
self.models[0].compute_or_load_stat(sampled_func, stat_file_path)
self.models[1].compute_or_load_stat(sampled_func, stat_file_path)

def change_energy_bias(self):
# need to implement
pass

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
21 changes: 18 additions & 3 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,24 @@ def compute_or_load_stat(
torch.tensor(bias_atom_e, device=env.DEVICE).view([self.ntypes, 1])
)

def change_energy_bias(self) -> None:
# need to implement
pass
def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None:
"""
Modify the output bias for the atomic model.
Parameters
----------
out_bias : torch.Tensor
The new bias to be applied.
add : bool, optional
Whether to add the new bias to the existing one.
If False, the output bias will be directly replaced by the new bias.
If True, the new bias will be added to the existing one.
"""
self.bias_atom_e = out_bias + self.bias_atom_e if add else out_bias

def get_out_bias(self) -> torch.Tensor:
"""Return the output bias of the atomic model."""
return self.bias_atom_e

def forward_atomic(
self,
Expand Down
Loading

0 comments on commit e47478f

Please sign in to comment.