Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: compute_output_stats and change_out_bias #3639

Merged
merged 4 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def atomic_output_def(self) -> FittingOutputDef:
"""
return self.fitting_output_def()

def get_output_keys(self) -> List[str]:
    """Return the names of all variables in the atomic output definition."""
    return [kk for kk in self.atomic_output_def().keys()]

@abstractmethod
def get_rcut(self) -> float:
"""Get the cut-off radius."""
Expand Down
123 changes: 50 additions & 73 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
List,
Optional,
Tuple,
Union,
)

import numpy as np
import torch

from deepmd.dpmodel.atomic_model import (
Expand All @@ -30,9 +30,6 @@
from deepmd.pt.utils.stat import (
compute_output_stats,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -190,115 +187,95 @@ def serialize(self) -> dict:
"pair_exclude_types": self.pair_exclude_types,
}

def get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
"""Get a forward wrapper of the atomic model for output bias calculation."""

def model_forward(coord, atype, box, fparam=None, aparam=None):
with torch.no_grad(): # it's essential for pure torch forward function to use auto_batchsize
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
self.get_sel(),
mixed_types=self.mixed_types(),
box=box,
)
atomic_ret = self.forward_common_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)
return {kk: vv.detach() for kk, vv in atomic_ret.items()}

return model_forward

def compute_or_load_stat(
self,
sampled_func,
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
):
"""
Compute or load the statistics parameters of the model,
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
and saved in the `stat_file_path`(s).
When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
and load the calculated statistics parameters.
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

Parameters
----------
sampled_func
The sampled data frames from different data systems.
stat_file_path
The path to the statistics files.
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.

"""
raise NotImplementedError

def change_out_bias(
self,
merged,
origin_type_map,
full_type_map,
sample_merged,
bias_adjust_mode="change-by-statistic",
) -> None:
"""Change the output bias according to the input data and the pretrained model.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
sample_merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
origin_type_map : List[str]
The original type_map in dataset, they are targets to change the output bias.
full_type_map : List[str]
The full type_map in pre-trained model
bias_adjust_mode : str
The mode for changing output bias : ['change-by-statistic', 'set-by-statistic']
'change-by-statistic' : perform predictions on labels of target dataset,
and do least square on the errors to obtain the target shift as bias.
'set-by-statistic' : directly use the statistic output bias in the target dataset.
"""
sorter = np.argsort(full_type_map)
missing_types = [t for t in origin_type_map if t not in full_type_map]
assert (
not missing_types
), f"Some types are not in the pre-trained model: {list(missing_types)} !"
idx_type_map = sorter[
np.searchsorted(full_type_map, origin_type_map, sorter=sorter)
]
original_bias = self.get_out_bias()
if bias_adjust_mode == "change-by-statistic":
delta_bias = compute_output_stats(
merged,
sample_merged,
self.get_ntypes(),
keys=["energy"],
model_forward=self.get_forward_wrapper_func(),
keys=self.get_output_keys(),
model_forward=self._get_forward_wrapper_func(),
)["energy"]
self.set_out_bias(delta_bias, add=True)
elif bias_adjust_mode == "set-by-statistic":
bias_atom = compute_output_stats(
merged,
sample_merged,
self.get_ntypes(),
keys=["energy"],
keys=self.get_output_keys(),
)["energy"]
self.set_out_bias(bias_atom)
else:
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
bias_atom = self.get_out_bias()
log.info(
f"Change output bias of {origin_type_map!s} "
f"from {to_numpy_array(original_bias[idx_type_map]).reshape(-1)!s} "
f"to {to_numpy_array(bias_atom[idx_type_map]).reshape(-1)!s}."
)

def _get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
    """Build a forward function for output-bias calculation.

    The returned callable ``f(coord, atype, box, fparam=None, aparam=None)``
    extends the raw inputs into a neighbor-list representation, runs the
    atomic model on them, and returns the detached atomic outputs.
    """

    def model_forward(coord, atype, box, fparam=None, aparam=None):
        # no_grad is essential for the pure torch forward function to be
        # usable with auto_batchsize.
        with torch.no_grad():
            extended = extend_input_and_build_neighbor_list(
                coord,
                atype,
                self.get_rcut(),
                self.get_sel(),
                mixed_types=self.mixed_types(),
                box=box,
            )
            extended_coord, extended_atype, mapping, nlist = extended
            atomic_ret = self.forward_common_atomic(
                extended_coord,
                extended_atype,
                nlist,
                mapping=mapping,
                fparam=fparam,
                aparam=aparam,
            )
            detached = {}
            for kk, vv in atomic_ret.items():
                detached[kk] = vv.detach()
            return detached

    return model_forward
11 changes: 3 additions & 8 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,12 @@ def forward_common(
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict

def get_out_bias(self) -> torch.Tensor:
    """Return the output bias tensor of the wrapped atomic model."""
    return self.atomic_model.get_out_bias()

def change_out_bias(
self,
merged,
origin_type_map,
full_type_map,
bias_adjust_mode="change-by-statistic",
) -> None:
"""Change the output bias of atomic model according to the input data and the pretrained model.
Expand All @@ -190,10 +191,6 @@ def change_out_bias(
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
origin_type_map : List[str]
The original type_map in dataset, they are targets to change the output bias.
full_type_map : List[str]
The full type_map in pre-trained model
bias_adjust_mode : str
The mode for changing output bias : ['change-by-statistic', 'set-by-statistic']
'change-by-statistic' : perform predictions on labels of target dataset,
Expand All @@ -202,8 +199,6 @@ def change_out_bias(
"""
self.atomic_model.change_out_bias(
merged,
origin_type_map,
full_type_map,
bias_adjust_mode=bias_adjust_mode,
)

Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/task/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,11 @@ def compute_output_stats(
bias_atom_e = compute_output_stats(
merged,
self.ntypes,
keys=["energy"],
keys=[self.var_name],
stat_file_path=stat_file_path,
rcond=self.rcond,
atom_ener=self.atom_ener,
)["energy"]
)[self.var_name]
self.bias_atom_e.copy_(bias_atom_e.view([self.ntypes, self.dim_out]))

def output_def(self) -> FittingOutputDef:
Expand Down
37 changes: 30 additions & 7 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,13 +570,8 @@
_model_params["new_type_map"],
)
if isinstance(_model, EnergyModel):
_model.change_out_bias(
_sample_func,
bias_adjust_mode=_model_params.get(
"bias_adjust_mode", "change-by-statistic"
),
origin_type_map=new_type_map,
full_type_map=old_type_map,
_model = _model_change_out_bias(
Fixed Show fixed Hide fixed
_model, new_type_map, _sample_func, _model_params
)
else:
# need to be updated
Expand Down Expand Up @@ -1148,3 +1143,31 @@
print_str += " %8.1e\n" % cur_lr
fout.write(print_str)
fout.flush()


def _model_change_out_bias(
    _model,
    new_type_map,
    _sample_func,
    _model_params,
):
    """Adjust the model's output bias to the target dataset and log the change.

    Parameters
    ----------
    _model
        The model whose output bias is updated in place.
    new_type_map : List[str]
        The type_map of the target dataset; every entry must be present in
        the pre-trained model's type_map.
    _sample_func
        A lazy function returning packed data samples used to compute the
        output statistics.
    _model_params : dict
        Model configuration; ``bias_adjust_mode`` selects how the bias is
        updated (defaults to "change-by-statistic").

    Returns
    -------
    The same model instance, with its output bias updated.

    Raises
    ------
    ValueError
        If some types in ``new_type_map`` are unknown to the model.
    """
    model_type_map = _model.get_type_map()
    # Validate the type map up front, before the (potentially expensive)
    # statistics pass; raise instead of assert so the check survives `-O`.
    missing_types = [t for t in new_type_map if t not in model_type_map]
    if missing_types:
        raise ValueError(
            f"Some types are not in the pre-trained model: {missing_types} !"
        )

    old_bias = _model.get_out_bias()
    _model.change_out_bias(
        _sample_func,
        bias_adjust_mode=_model_params.get("bias_adjust_mode", "change-by-statistic"),
    )
    new_bias = _model.get_out_bias()

    # Map each requested type name to its index in the model's type_map so
    # the logged bias values line up with `new_type_map`'s ordering.
    sorter = np.argsort(model_type_map)
    idx_type_map = sorter[np.searchsorted(model_type_map, new_type_map, sorter=sorter)]
    log.info(
        f"Change output bias of {new_type_map!s} "
        f"from {to_numpy_array(old_bias[idx_type_map]).reshape(-1)!s} "
        f"to {to_numpy_array(new_bias[idx_type_map]).reshape(-1)!s}."
    )
    return _model
Loading