Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

breaking: pt: remove data stat from model init #3245

Merged
merged 9 commits into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def compute_input_stats(self, merged):
pass

@abstractmethod
def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
def init_desc_stat(self, stat_dict):
"""Initialize the model bias by the statistics."""
pass

Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,14 @@ def __getitem__(self, key):
else:
raise KeyError(key)

def compute_output_stats(self, merged):
"""Update the output bias for fitting net."""
raise NotImplementedError

def init_fitting_stat(self, result_dict):
"""Initialize the model bias by the statistics."""
raise NotImplementedError

def serialize(self) -> dict:
"""Serialize the fitting to dict."""
return {
Expand Down
10 changes: 10 additions & 0 deletions deepmd/dpmodel/fitting/make_base_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ def fwd(
"""Calculate fitting."""
pass

@abstractmethod
def compute_output_stats(self, merged):
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
"""Update the output bias for fitting net."""
pass

@abstractmethod
def init_fitting_stat(self, result_dict):
"""Initialize the model bias by the statistics."""
pass

@abstractmethod
def serialize(self) -> dict:
"""Serialize the obj to dict."""
Expand Down
26 changes: 26 additions & 0 deletions deepmd/dpmodel/model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,32 @@ def serialize(self) -> dict:
def deserialize(cls):
pass

def compute_or_load_stat(
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
self,
type_map=None,
sampled=None,
stat_file_path=None,
):
"""
Compute or load the statistics parameters of the model,
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
and saved in the `stat_file_path`(s).
When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
and load the calculated statistics parameters.

Parameters
----------
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
sampled
The sampled data frames from different data systems.
stat_file_path
The path to the statistics files.
"""
pass
iProzd marked this conversation as resolved.
Show resolved Hide resolved

def do_grad(
self,
var_name: Optional[str] = None,
Expand Down
71 changes: 32 additions & 39 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
from deepmd.pt.model.descriptor import (
Descriptor,
)
from deepmd.pt.model.task import (
Fitting,
)
from deepmd.pt.train import (
training,
)
Expand Down Expand Up @@ -122,49 +125,39 @@
hybrid_descrpt = model_params_single["descriptor"]["type"] == "hybrid"
has_stat_file_path = True
if not hybrid_descrpt:
### this design requires "rcut", "rcut_smth" and "sel" in the descriptor
### VERY BAD DESIGN!!!!
### not all descriptors provides these parameter in their constructor
default_stat_file_name = Descriptor.get_stat_name(
model_params_single["descriptor"]
)
model_params_single["stat_file_dir"] = data_dict_single.get(
"stat_file_dir", f"stat_files{suffix}"
)
model_params_single["stat_file"] = data_dict_single.get(
"stat_file", default_stat_file_name
)
model_params_single["stat_file_path"] = os.path.join(
model_params_single["stat_file_dir"], model_params_single["stat_file"]
)
if not os.path.exists(model_params_single["stat_file_path"]):
stat_file = data_dict_single.get("stat_file", None)
if stat_file is None:
stat_file = {}
if "descriptor" in model_params_single:
default_stat_file_name_descrpt = Descriptor.get_stat_name(
model_params_single["descriptor"],
len(model_params_single["type_map"]),
)
stat_file["descriptor"] = default_stat_file_name_descrpt
if "fitting_net" in model_params_single:
default_stat_file_name_fitting = Fitting.get_stat_name(
model_params_single["fitting_net"],
len(model_params_single["type_map"]),
)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
stat_file["fitting_net"] = default_stat_file_name_fitting
model_params_single["stat_file_path"] = {
Fixed Show fixed Hide fixed
key: os.path.join(model_params_single["stat_file_dir"], stat_file[key])
for key in stat_file
}

has_stat_file_path_list = [
os.path.exists(model_params_single["stat_file_path"][key])
for key in stat_file
]
if False in has_stat_file_path_list:
has_stat_file_path = False
else: ### need to remove this
default_stat_file_name = []
for descrpt in model_params_single["descriptor"]["list"]:
default_stat_file_name.append(
f'stat_file_rcut{descrpt["rcut"]:.2f}_'
f'smth{descrpt["rcut_smth"]:.2f}_'
f'sel{descrpt["sel"]}_{descrpt["type"]}.npz'
)
model_params_single["stat_file_dir"] = data_dict_single.get(
"stat_file_dir", f"stat_files{suffix}"
else: ### TODO hybrid descriptor not implemented
raise NotImplementedError(
"data stat for hybrid descriptor is not implemented!"
)
model_params_single["stat_file"] = data_dict_single.get(
"stat_file", default_stat_file_name
)
assert isinstance(
model_params_single["stat_file"], list
), "Stat file of hybrid descriptor must be a list!"
stat_file_path = []
for stat_file_path_item in model_params_single["stat_file"]:
single_file_path = os.path.join(
model_params_single["stat_file_dir"], stat_file_path_item
)
stat_file_path.append(single_file_path)
if not os.path.exists(single_file_path):
has_stat_file_path = False
model_params_single["stat_file_path"] = stat_file_path

# validation and training data
validation_data_single = DpLoaderSet(
Expand Down Expand Up @@ -227,7 +220,7 @@
trainer = training.Trainer(
config,
train_data,
sampled,
sampled=sampled,
validation_data=validation_data,
init_model=init_model,
restart_model=restart_model,
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
assert not self.multi_task, "multitask mode currently not supported!"
self.type_split = self.input_param["descriptor"]["type"] in ["se_e2_a"]
self.type_map = self.input_param["type_map"]
self.dp = ModelWrapper(get_model(self.input_param, None).to(DEVICE))
self.dp = ModelWrapper(get_model(self.input_param).to(DEVICE))
self.dp.load_state_dict(state_dict)
self.rcut = self.dp.model["Default"].descriptor.get_rcut()
self.sec = np.cumsum(self.dp.model["Default"].descriptor.get_sel())
Expand Down
137 changes: 127 additions & 10 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from abc import (
ABC,
abstractmethod,
Expand All @@ -7,6 +8,7 @@
Callable,
List,
Optional,
Union,
)

import numpy as np
Expand All @@ -23,6 +25,8 @@
BaseDescriptor,
)

log = logging.getLogger(__name__)


class Descriptor(torch.nn.Module, BaseDescriptor):
"""The descriptor.
Expand Down Expand Up @@ -56,15 +60,127 @@ class SomeDescript(Descriptor):
return Descriptor.__plugins.register(key)

@classmethod
def get_stat_name(cls, config):
def get_stat_name(cls, config, ntypes):
"""
Get the name for the statistic file of the descriptor.
Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name.
"""
if cls is not Descriptor:
raise NotImplementedError("get_stat_name is not implemented!")
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config)
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config, ntypes)

@classmethod
def get_data_process_key(cls, config):
iProzd marked this conversation as resolved.
Show resolved Hide resolved
"""
Get the keys for the data preprocess.
Usually need the information of rcut and sel.
TODO Need to be deprecated when the dataloader has been cleaned up.
"""
if cls is not Descriptor:
raise NotImplementedError("get_data_process_key is not implemented!")
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config)

def get_data_stat_key(self):
iProzd marked this conversation as resolved.
Show resolved Hide resolved
"""
Get the keys for the data statistic of the descriptor.
Return a list of statistic names needed, such as "sumr", "suma" or "sumn".
"""
raise NotImplementedError("get_data_stat_key is not implemented!")

def set_stats(
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
self,
type_map: List[str],
sampled,
stat_file_path: Optional[Union[str, List[str]]] = None,
):
"""
Set the statistics parameters for the descriptor.
Calculate and save the mean and standard deviation of the descriptor to `stat_file_path`
if `sampled` is not None, otherwise load them from `stat_file_path`.

Parameters
----------
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
sampled
The sampled data frames from different data systems.
stat_file_path
The path to the statistics files.
"""
# TODO support hybrid descriptor
descrpt_stat_key = self.get_data_stat_key()
if sampled is not None: # compute the statistics results
tmp_dict = self.compute_input_stats(sampled)
result_dict = {key: tmp_dict[key] for key in descrpt_stat_key}
result_dict["type_map"] = type_map
if stat_file_path is not None:
self.save_stats(result_dict, stat_file_path)
else: # load the statistics results
assert stat_file_path is not None, "No stat file to load!"
result_dict = self.load_stats(type_map, stat_file_path)
self.init_desc_stat(result_dict)

def save_stats(self, result_dict, stat_file_path: Union[str, List[str]]):
"""
Save the statistics results to `stat_file_path`.

Parameters
----------
result_dict
The dictionary of statistics results.
stat_file_path
The path to the statistics file(s).
"""
if not isinstance(stat_file_path, list):
log.info(f"Saving stat file to {stat_file_path}")
np.savez_compressed(stat_file_path, **result_dict)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
else: # TODO hybrid descriptor not implemented
raise NotImplementedError(
"save_stats for hybrid descriptor is not implemented!"
)

def load_stats(self, type_map, stat_file_path: Union[str, List[str]]):
"""
Load the statistics results to `stat_file_path`.

Parameters
----------
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
stat_file_path
The path to the statistics file(s).

Returns
-------
result_dict
The dictionary of statistics results.
"""
descrpt_stat_key = self.get_data_stat_key()
target_type_map = type_map
if not isinstance(stat_file_path, list):
log.info(f"Loading stat file from {stat_file_path}")
stats = np.load(stat_file_path)
stat_type_map = list(stats["type_map"])
missing_type = [i for i in target_type_map if i not in stat_type_map]
assert not missing_type, (
f"These type are not in stat file {stat_file_path}: {missing_type}! "
f"Please change the stat file path!"
)
idx_map = [stat_type_map.index(i) for i in target_type_map]
if stats[descrpt_stat_key[0]].size: # not empty
result_dict = {key: stats[key][idx_map] for key in descrpt_stat_key}
else:
result_dict = {key: [] for key in descrpt_stat_key}
else: # TODO hybrid descriptor not implemented
raise NotImplementedError(
"load_stats for hybrid descriptor is not implemented!"
)
return result_dict

def __new__(cls, *args, **kwargs):
if cls is Descriptor:
try:
Expand Down Expand Up @@ -162,7 +278,7 @@ def compute_input_stats(self, merged):
pass

@abstractmethod
def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
def init_desc_stat(self, stat_dict):
"""Initialize the model bias by the statistics."""
iProzd marked this conversation as resolved.
Show resolved Hide resolved
pass

Expand All @@ -188,13 +304,14 @@ def share_params(self, base_class, shared_level, resume=False):
self.sumr2,
self.suma2,
)
base_class.init_desc_stat(
sumr_base + sumr,
suma_base + suma,
sumn_base + sumn,
sumr2_base + sumr2,
suma2_base + suma2,
)
stat_dict = {
"sumr": sumr_base + sumr,
"suma": suma_base + suma,
"sumn": sumn_base + sumn,
"sumr2": sumr2_base + sumr2,
"suma2": suma2_base + suma2,
}
base_class.init_desc_stat(stat_dict)
self.mean = base_class.mean
self.stddev = base_class.stddev
# self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model
Expand Down
Loading