Skip to content

Commit

Permalink
feat(pt): consistent fine-tuning with init-model (#3803)
Browse files Browse the repository at this point in the history
Fix #3747. Fix #3455. 

- Consistent fine-tuning with init-model, now in pt, fine-tuning include
three steps:
1. Change model params (for multitask fine-tuning, random fitting and
type-related params),
2. Init-model, 
3. Change bias

- By default, input will use user input while fine-tuning, instead of
being overwritten by that in the pre-trained model. When adding
“--use-pretrain-script”, user can use that in the pre-trained model.

- Now `type_map` will use that in the user input instead of overwritten
by that in the pre-trained model.

Note:
1. After discussed with @wanghan-iapcm, **behavior of fine-tuning in TF
is kept as before**. If needed in the future, it can be implemented
then.
2. Fine-tuning using DOSModel in PT need to be fixed. (an issue will be
opened, maybe fixed in another PR, cc @anyangml )

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Added support for using model parameters from a pretrained model
script.
- Introduced new methods to handle type-related parameters and
fine-tuning configurations.

- **Documentation**
- Updated documentation to clarify the model section requirements and
the new `--use-pretrain-script` option for fine-tuning.

- **Refactor**
- Simplified and improved the readability of key functions related to
model training and fine-tuning.

- **Tests**
- Added new test methods and utility functions to ensure consistency of
type mapping and parameter updates.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Duo <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Jun 13, 2024
1 parent 6d378f4 commit a1a3840
Show file tree
Hide file tree
Showing 92 changed files with 3,014 additions and 496 deletions.
22 changes: 22 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
AtomExcludeMask,
PairExcludeMask,
)
from deepmd.utils.finetune import (
get_index_between_two_maps,
map_atom_exclude_types,
map_pair_exclude_types,
)

from .make_base_atomic_model import (
make_base_atomic_model,
Expand Down Expand Up @@ -113,6 +118,23 @@ def atomic_output_def(self) -> FittingOutputDef:
]
)

def change_type_map(
self, type_map: List[str], model_with_new_type_stat=None
) -> None:
"""Change the type related params to new ones, according to `type_map` and the original one in the model.
If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
"""
remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map)
self.type_map = type_map
self.reinit_atom_exclude(
map_atom_exclude_types(self.atom_exclude_types, remap_index)
)
self.reinit_pair_exclude(
map_pair_exclude_types(self.pair_exclude_types, remap_index)
)
self.out_bias = self.out_bias[:, remap_index, :]
self.out_std = self.out_std[:, remap_index, :]

def forward_common_atomic(
self,
extended_coord: np.ndarray,
Expand Down
18 changes: 18 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,24 @@ def forward_atomic(
)
return ret

def change_type_map(
self, type_map: List[str], model_with_new_type_stat=None
) -> None:
"""Change the type related params to new ones, according to `type_map` and the original one in the model.
If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
"""
super().change_type_map(
type_map=type_map, model_with_new_type_stat=model_with_new_type_stat
)
self.type_map = type_map
self.descriptor.change_type_map(
type_map=type_map,
model_with_new_type_stat=model_with_new_type_stat.descriptor
if model_with_new_type_stat is not None
else None,
)
self.fitting_net.change_type_map(type_map=type_map)

def serialize(self) -> dict:
dd = super().serialize()
dd.update(
Expand Down
17 changes: 17 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,23 @@ def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

def change_type_map(
self, type_map: List[str], model_with_new_type_stat=None
) -> None:
"""Change the type related params to new ones, according to `type_map` and the original one in the model.
If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
"""
super().change_type_map(
type_map=type_map, model_with_new_type_stat=model_with_new_type_stat
)
for ii, model in enumerate(self.models):
model.change_type_map(
type_map=type_map,
model_with_new_type_stat=model_with_new_type_stat.models[ii]
if model_with_new_type_stat is not None
else None,
)

def get_model_rcuts(self) -> List[float]:
"""Get the cut-off radius for each individual models."""
return [model.get_rcut() for model in self.models]
Expand Down
6 changes: 6 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ def serialize(self) -> dict:
def deserialize(cls, data: dict):
pass

@abstractmethod
def change_type_map(
self, type_map: List[str], model_with_new_type_stat=None
) -> None:
pass

def make_atom_mask(
self,
atype: t_tensor,
Expand Down
12 changes: 12 additions & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,18 @@ def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return False

def change_type_map(
self, type_map: List[str], model_with_new_type_stat=None
) -> None:
"""Change the type related params to new ones, according to `type_map` and the original one in the model.
If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
"""
assert type_map == self.type_map, (
"PairTabAtomicModel does not support changing type map now. "
"This feature is currently not implemented because it would require additional work to change the tab file. "
"We may consider adding this support in the future if there is a clear demand for it."
)

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
34 changes: 34 additions & 0 deletions deepmd/dpmodel/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,37 @@ def call(
@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""


def extend_descrpt_stat(des, type_map, des_with_stat=None):
r"""
Extend the statistics of a descriptor block with types from newly provided `type_map`.
After extending, the type related dimension of the extended statistics will have a length of
`len(old_type_map) + len(type_map)`, where `old_type_map` represents the type map in `des`.
The `get_index_between_two_maps()` function can then be used to correctly select statistics for types
from `old_type_map` or `type_map`.
Positive indices from 0 to `len(old_type_map) - 1` will select old statistics of types in `old_type_map`,
while negative indices from `-len(type_map)` to -1 will select new statistics of types in `type_map`.
Parameters
----------
des : DescriptorBlock
The descriptor block to be extended.
type_map : List[str]
The name of each type of atoms to be extended.
des_with_stat : DescriptorBlock, Optional
The descriptor block has additional statistics of types from newly provided `type_map`.
If None, the default statistics will be used.
Otherwise, the statistics provided in this DescriptorBlock will be used.
"""
if des_with_stat is not None:
extend_davg = des_with_stat["davg"]
extend_dstd = des_with_stat["dstd"]
else:
extend_shape = [len(type_map), *list(des["davg"].shape[1:])]
extend_davg = np.zeros(extend_shape, dtype=des["davg"].dtype)
extend_dstd = np.ones(extend_shape, dtype=des["dstd"].dtype)
des["davg"] = np.concatenate([des["davg"], extend_davg], axis=0)
des["dstd"] = np.concatenate([des["dstd"], extend_dstd], axis=0)
43 changes: 41 additions & 2 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.finetune import (
get_index_between_two_maps,
map_pair_exclude_types,
)
from deepmd.utils.path import (
DPPath,
)
Expand All @@ -49,6 +53,7 @@
)
from .descriptor import (
DescriptorBlock,
extend_descrpt_stat,
)


Expand Down Expand Up @@ -194,8 +199,6 @@ class DescrptDPA1(NativeOP, BaseDescriptor):
Whether to use electronic configuration type embedding.
type_map: List[str], Optional
A list of strings. Give the name to each type of atoms.
Only used if `use_econf_tebd` is `True` in type embedding net.
spin
(Only support None to keep consistent with other backend references.)
(Not used in this version. Not-none option is not implemented.)
Expand Down Expand Up @@ -327,6 +330,10 @@ def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.se_atten.get_ntypes()

def get_type_map(self) -> List[str]:
"""Get the name to each type of atoms."""
return self.type_map

def get_dim_out(self) -> int:
"""Returns the output dimension."""
ret = self.se_atten.get_dim_out()
Expand Down Expand Up @@ -382,9 +389,41 @@ def set_stat_mean_and_stddev(
mean: np.ndarray,
stddev: np.ndarray,
) -> None:
"""Update mean and stddev for descriptor."""
self.se_atten.mean = mean
self.se_atten.stddev = stddev

def get_stat_mean_and_stddev(self) -> Tuple[np.ndarray, np.ndarray]:
"""Get mean and stddev for descriptor."""
return self.se_atten.mean, self.se_atten.stddev

def change_type_map(
self, type_map: List[str], model_with_new_type_stat=None
) -> None:
"""Change the type related params to new ones, according to `type_map` and the original one in the model.
If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
"""
assert (
self.type_map is not None
), "'type_map' must be defined when performing type changing!"
remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map)
obj = self.se_atten
obj.ntypes = len(type_map)
self.type_map = type_map
self.type_embedding.change_type_map(type_map=type_map)
obj.reinit_exclude(map_pair_exclude_types(obj.exclude_types, remap_index))
if has_new_type:
# the avg and std of new types need to be updated
extend_descrpt_stat(
obj,
type_map,
des_with_stat=model_with_new_type_stat.se_atten
if model_with_new_type_stat is not None
else None,
)
obj["davg"] = obj["davg"][remap_index]
obj["dstd"] = obj["dstd"][remap_index]

def call(
self,
coord_ext,
Expand Down
70 changes: 69 additions & 1 deletion deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.finetune import (
get_index_between_two_maps,
map_pair_exclude_types,
)
from deepmd.utils.path import (
DPPath,
)
Expand All @@ -42,6 +46,9 @@
from .base_descriptor import (
BaseDescriptor,
)
from .descriptor import (
extend_descrpt_stat,
)
from .dpa1 import (
DescrptBlockSeAtten,
)
Expand Down Expand Up @@ -353,7 +360,6 @@ def __init__(
Whether to use electronic configuration type embedding.
type_map : List[str], Optional
A list of strings. Give the name to each type of atoms.
Only used if `use_econf_tebd` is `True` in type embedding net.
Returns
-------
Expand Down Expand Up @@ -501,6 +507,10 @@ def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.ntypes

def get_type_map(self) -> List[str]:
"""Get the name to each type of atoms."""
return self.type_map

def get_dim_out(self) -> int:
"""Returns the output dimension of this descriptor."""
ret = self.repformers.dim_out
Expand Down Expand Up @@ -542,6 +552,47 @@ def share_params(self, base_class, shared_level, resume=False):
"""
raise NotImplementedError

def change_type_map(
self, type_map: List[str], model_with_new_type_stat=None
) -> None:
"""Change the type related params to new ones, according to `type_map` and the original one in the model.
If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
"""
assert (
self.type_map is not None
), "'type_map' must be defined when performing type changing!"
remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map)
self.type_map = type_map
self.type_embedding.change_type_map(type_map=type_map)
self.exclude_types = map_pair_exclude_types(self.exclude_types, remap_index)
self.ntypes = len(type_map)
repinit = self.repinit
repformers = self.repformers
if has_new_type:
# the avg and std of new types need to be updated
extend_descrpt_stat(
repinit,
type_map,
des_with_stat=model_with_new_type_stat.repinit
if model_with_new_type_stat is not None
else None,
)
extend_descrpt_stat(
repformers,
type_map,
des_with_stat=model_with_new_type_stat.repformers
if model_with_new_type_stat is not None
else None,
)
repinit.ntypes = self.ntypes
repformers.ntypes = self.ntypes
repinit.reinit_exclude(self.exclude_types)
repformers.reinit_exclude(self.exclude_types)
repinit["davg"] = repinit["davg"][remap_index]
repinit["dstd"] = repinit["dstd"][remap_index]
repformers["davg"] = repformers["davg"][remap_index]
repformers["dstd"] = repformers["dstd"][remap_index]

@property
def dim_out(self):
return self.get_dim_out()
Expand All @@ -555,6 +606,23 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None)
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

def set_stat_mean_and_stddev(
self,
mean: List[np.ndarray],
stddev: List[np.ndarray],
) -> None:
"""Update mean and stddev for descriptor."""
for ii, descrpt in enumerate([self.repinit, self.repformers]):
descrpt.mean = mean[ii]
descrpt.stddev = stddev[ii]

def get_stat_mean_and_stddev(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""Get mean and stddev for descriptor."""
return [self.repinit.mean, self.repformers.mean], [
self.repinit.stddev,
self.repformers.stddev,
]

def call(
self,
coord_ext: np.ndarray,
Expand Down
Loading

0 comments on commit a1a3840

Please sign in to comment.