diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index 2b0025af07..29d3ad6d92 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -69,15 +69,13 @@ def distinguish_types(self) -> bool: """ pass - @abstractmethod def compute_input_stats(self, merged): """Update mean and stddev for descriptor elements.""" - pass + raise NotImplementedError - @abstractmethod - def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + def init_desc_stat(self, **kwargs): """Initialize the model bias by the statistics.""" - pass + raise NotImplementedError @abstractmethod def fwd( diff --git a/deepmd/dpmodel/fitting/invar_fitting.py b/deepmd/dpmodel/fitting/invar_fitting.py index 820f422ef0..58607a9f26 100644 --- a/deepmd/dpmodel/fitting/invar_fitting.py +++ b/deepmd/dpmodel/fitting/invar_fitting.py @@ -236,6 +236,14 @@ def __getitem__(self, key): else: raise KeyError(key) + def compute_output_stats(self, merged): + """Update the output bias for fitting net.""" + raise NotImplementedError + + def init_fitting_stat(self, result_dict): + """Initialize the model bias by the statistics.""" + raise NotImplementedError + def serialize(self) -> dict: """Serialize the fitting to dict.""" return { diff --git a/deepmd/dpmodel/fitting/make_base_fitting.py b/deepmd/dpmodel/fitting/make_base_fitting.py index 719ac6169e..620ff316f1 100644 --- a/deepmd/dpmodel/fitting/make_base_fitting.py +++ b/deepmd/dpmodel/fitting/make_base_fitting.py @@ -52,6 +52,14 @@ def fwd( """Calculate fitting.""" pass + def compute_output_stats(self, merged): + """Update the output bias for fitting net.""" + raise NotImplementedError + + def init_fitting_stat(self, **kwargs): + """Initialize the model bias by the statistics.""" + raise NotImplementedError + @abstractmethod def serialize(self) -> dict: """Serialize the obj to dict.""" diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 04bb2c2d7e..702fc6f317 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -46,6 +46,9 @@ from deepmd.pt.model.descriptor import ( Descriptor, ) +from deepmd.pt.model.task import ( + Fitting, +) from deepmd.pt.train import ( training, ) @@ -63,6 +66,7 @@ ) from deepmd.pt.utils.stat import ( make_stat_input, + process_stat_path, ) from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter @@ -128,51 +132,18 @@ def prepare_trainer_input_single( # stat files hybrid_descrpt = model_params_single["descriptor"]["type"] == "hybrid" - has_stat_file_path = True if not hybrid_descrpt: - ### this design requires "rcut", "rcut_smth" and "sel" in the descriptor - ### VERY BAD DESIGN!!!! 
- ### not all descriptors provides these parameter in their constructor - default_stat_file_name = Descriptor.get_stat_name( - model_params_single["descriptor"] - ) - model_params_single["stat_file_dir"] = data_dict_single.get( - "stat_file_dir", f"stat_files{suffix}" - ) - model_params_single["stat_file"] = data_dict_single.get( - "stat_file", default_stat_file_name - ) - model_params_single["stat_file_path"] = os.path.join( - model_params_single["stat_file_dir"], model_params_single["stat_file"] - ) - if not os.path.exists(model_params_single["stat_file_path"]): - has_stat_file_path = False - else: ### need to remove this - default_stat_file_name = [] - for descrpt in model_params_single["descriptor"]["list"]: - default_stat_file_name.append( - f'stat_file_rcut{descrpt["rcut"]:.2f}_' - f'smth{descrpt["rcut_smth"]:.2f}_' - f'sel{descrpt["sel"]}_{descrpt["type"]}.npz' - ) - model_params_single["stat_file_dir"] = data_dict_single.get( - "stat_file_dir", f"stat_files{suffix}" + stat_file_path_single, has_stat_file_path = process_stat_path( + data_dict_single.get("stat_file", None), + data_dict_single.get("stat_file_dir", f"stat_files{suffix}"), + model_params_single, + Descriptor, + Fitting, ) - model_params_single["stat_file"] = data_dict_single.get( - "stat_file", default_stat_file_name + else: ### TODO hybrid descriptor not implemented + raise NotImplementedError( + "data stat for hybrid descriptor is not implemented!" ) - assert isinstance( - model_params_single["stat_file"], list - ), "Stat file of hybrid descriptor must be a list!" - stat_file_path = [] - for stat_file_path_item in model_params_single["stat_file"]: - single_file_path = os.path.join( - model_params_single["stat_file_dir"], stat_file_path_item - ) - stat_file_path.append(single_file_path) - if not os.path.exists(single_file_path): - has_stat_file_path = False - model_params_single["stat_file_path"] = stat_file_path # validation and training data validation_data_single = DpLoaderSet( @@ -212,19 +183,30 @@ def prepare_trainer_input_single( type_split=type_split, noise_settings=noise_settings, ) - return train_data_single, validation_data_single, sampled_single + return ( + train_data_single, + validation_data_single, + sampled_single, + stat_file_path_single, + ) if not multi_task: - train_data, validation_data, sampled = prepare_trainer_input_single( + ( + train_data, + validation_data, + sampled, + stat_file_path, + ) = prepare_trainer_input_single( config["model"], config["training"], config["loss"] ) else: - train_data, validation_data, sampled = {}, {}, {} + train_data, validation_data, sampled, stat_file_path = {}, {}, {}, {} for model_key in config["model"]["model_dict"]: ( train_data[model_key], validation_data[model_key], sampled[model_key], + stat_file_path[model_key], ) = prepare_trainer_input_single( config["model"]["model_dict"][model_key], config["training"]["data_dict"][model_key], @@ -235,7 +217,8 @@ def prepare_trainer_input_single( trainer = training.Trainer( config, train_data, - sampled, + sampled=sampled, + stat_file_path=stat_file_path, validation_data=validation_data, init_model=init_model, restart_model=restart_model, diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 325b1a56a4..9542ff33b1 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -99,7 +99,7 @@ def __init__( self.input_param["resuming"] = True self.multi_task = "model_dict" in self.input_param assert not self.multi_task, "multitask mode currently not supported!" 
- model = get_model(self.input_param, None).to(DEVICE) + model = get_model(self.input_param).to(DEVICE) model = torch.jit.script(model) self.dp = ModelWrapper(model) self.dp.load_state_dict(state_dict) diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index b4e866bb11..177f30d241 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import logging from abc import ( ABC, abstractmethod, @@ -7,6 +8,7 @@ Callable, List, Optional, + Union, ) import numpy as np @@ -23,6 +25,8 @@ BaseDescriptor, ) +log = logging.getLogger(__name__) + class Descriptor(torch.nn.Module, BaseDescriptor): """The descriptor. @@ -56,15 +60,130 @@ class SomeDescript(Descriptor): return Descriptor.__plugins.register(key) @classmethod - def get_stat_name(cls, config): - descrpt_type = config["type"] - return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config) + def get_stat_name(cls, ntypes, type_name, **kwargs): + """ + Get the name for the statistic file of the descriptor. + Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name. + """ + if cls is not Descriptor: + raise NotImplementedError("get_stat_name is not implemented!") + descrpt_type = type_name + return Descriptor.__plugins.plugins[descrpt_type].get_stat_name( + ntypes, type_name, **kwargs + ) @classmethod def get_data_process_key(cls, config): + """ + Get the keys for the data preprocess. + Usually need the information of rcut and sel. + TODO Need to be deprecated when the dataloader has been cleaned up. + """ + if cls is not Descriptor: + raise NotImplementedError("get_data_process_key is not implemented!") descrpt_type = config["type"] return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config) + @property + def data_stat_key(self): + """ + Get the keys for the data statistic of the descriptor. + Return a list of statistic names needed, such as "sumr", "suma" or "sumn". + """ + raise NotImplementedError("data_stat_key is not implemented!") + + def compute_or_load_stat( + self, + type_map: List[str], + sampled=None, + stat_file_path: Optional[Union[str, List[str]]] = None, + ): + """ + Compute or load the statistics parameters of the descriptor. + Calculate and save the mean and standard deviation of the descriptor to `stat_file_path` + if `sampled` is not None, otherwise load them from `stat_file_path`. + + Parameters + ---------- + type_map + Mapping atom type to the name (str) of the type. + For example `type_map[1]` gives the name of the type 1. + sampled + The sampled data frames from different data systems. + stat_file_path + The path to the statistics files. + """ + # TODO support hybrid descriptor + descrpt_stat_key = self.data_stat_key + if sampled is not None: # compute the statistics results + tmp_dict = self.compute_input_stats(sampled) + result_dict = {key: tmp_dict[key] for key in descrpt_stat_key} + result_dict["type_map"] = type_map + if stat_file_path is not None: + self.save_stats(result_dict, stat_file_path) + else: # load the statistics results + assert stat_file_path is not None, "No stat file to load!" + result_dict = self.load_stats(type_map, stat_file_path) + self.init_desc_stat(**result_dict) + + def save_stats(self, result_dict, stat_file_path: Union[str, List[str]]): + """ + Save the statistics results to `stat_file_path`. + + Parameters + ---------- + result_dict + The dictionary of statistics results. 
+        stat_file_path
+            The path to the statistics file(s).
+        """
+        if not isinstance(stat_file_path, list):
+            log.info(f"Saving stat file to {stat_file_path}")
+            np.savez_compressed(stat_file_path, **result_dict)
+        else:  # TODO hybrid descriptor not implemented
+            raise NotImplementedError(
+                "save_stats for hybrid descriptor is not implemented!"
+            )
+
+    def load_stats(self, type_map, stat_file_path: Union[str, List[str]]):
+        """
+        Load the statistics results from `stat_file_path`.
+
+        Parameters
+        ----------
+        type_map
+            Mapping atom type to the name (str) of the type.
+            For example `type_map[1]` gives the name of the type 1.
+        stat_file_path
+            The path to the statistics file(s).
+
+        Returns
+        -------
+        result_dict
+            The dictionary of statistics results.
+        """
+        descrpt_stat_key = self.data_stat_key
+        target_type_map = type_map
+        if not isinstance(stat_file_path, list):
+            log.info(f"Loading stat file from {stat_file_path}")
+            stats = np.load(stat_file_path)
+            stat_type_map = list(stats["type_map"])
+            missing_type = [i for i in target_type_map if i not in stat_type_map]
+            assert not missing_type, (
+                f"These types are not in the stat file {stat_file_path}: {missing_type}! "
+                f"Please change the stat file path!"
+            )
+            idx_map = [stat_type_map.index(i) for i in target_type_map]
+            if stats[descrpt_stat_key[0]].size:  # not empty
+                result_dict = {key: stats[key][idx_map] for key in descrpt_stat_key}
+            else:
+                result_dict = {key: [] for key in descrpt_stat_key}
+        else:  # TODO hybrid descriptor not implemented
+            raise NotImplementedError(
+                "load_stats for hybrid descriptor is not implemented!"
+            )
+        return result_dict
+
     def __new__(cls, *args, **kwargs):
         if cls is Descriptor:
             try:
@@ -156,15 +275,13 @@ def get_dim_emb(self) -> int:
         """Returns the embedding dimension."""
         pass
 
-    @abstractmethod
     def compute_input_stats(self, merged):
         """Update mean and stddev for DescriptorBlock elements."""
-        pass
+        raise NotImplementedError
 
-    @abstractmethod
-    def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
-        """Initialize the model bias by the statistics."""
-        pass
+    def init_desc_stat(self, **kwargs):
+        """Initialize mean and stddev by the statistics."""
+        raise NotImplementedError
 
     def share_params(self, base_class, shared_level, resume=False):
         assert (
@@ -188,13 +305,14 @@ def share_params(self, base_class, shared_level, resume=False):
                     self.sumr2,
                     self.suma2,
                 )
-                base_class.init_desc_stat(
-                    sumr_base + sumr,
-                    suma_base + suma,
-                    sumn_base + sumn,
-                    sumr2_base + sumr2,
-                    suma2_base + suma2,
-                )
+                stat_dict = {
+                    "sumr": sumr_base + sumr,
+                    "suma": suma_base + suma,
+                    "sumn": sumn_base + sumn,
+                    "sumr2": sumr2_base + sumr2,
+                    "suma2": suma2_base + suma2,
+                }
+                base_class.init_desc_stat(**stat_dict)
                 self.mean = base_class.mean
                 self.stddev = base_class.stddev
             # self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model
diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py
index 914c37ed51..6c1331ec1d 100644
--- a/deepmd/pt/model/descriptor/dpa1.py
+++ b/deepmd/pt/model/descriptor/dpa1.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import logging
 from typing import (
     List,
     Optional,
@@ -17,6 +18,8 @@
     DescrptBlockSeAtten,
 )
 
+log = logging.getLogger(__name__)
+
 
 @Descriptor.register("dpa1")
 @Descriptor.register("se_atten")
@@ -122,21 +125,43 @@ def dim_emb(self):
     def compute_input_stats(self, merged):
         return self.se_atten.compute_input_stats(merged)
 
-    def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
+    def init_desc_stat(
+        self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
+    ):
+        assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
         self.se_atten.init_desc_stat(sumr, suma, sumn, sumr2, suma2)
 
     @classmethod
-    def get_stat_name(cls, config):
-        descrpt_type = config["type"]
+    def get_stat_name(
+        cls, ntypes, type_name, rcut=None, rcut_smth=None, sel=None, **kwargs
+    ):
+        """
+        Get the name for the statistic file of the descriptor.
+        Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name.
+        """
+        descrpt_type = type_name
         assert descrpt_type in ["dpa1", "se_atten"]
-        return f'stat_file_dpa1_rcut{config["rcut"]:.2f}_smth{config["rcut_smth"]:.2f}_sel{config["sel"]}.npz'
+        return f"stat_file_descrpt_dpa1_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz"
 
     @classmethod
     def get_data_process_key(cls, config):
+        """
+        Get the keys for the data preprocess.
+        Usually need the information of rcut and sel.
+        TODO Need to be deprecated when the dataloader has been cleaned up.
+        """
         descrpt_type = config["type"]
         assert descrpt_type in ["dpa1", "se_atten"]
         return {"sel": config["sel"], "rcut": config["rcut"]}
 
+    @property
+    def data_stat_key(self):
+        """
+        Get the keys for the data statistic of the descriptor.
+        Return a list of statistic names needed, such as "sumr", "suma" or "sumn".
+        """
+        return ["sumr", "suma", "sumn", "sumr2", "suma2"]
+
     def serialize(self) -> dict:
         """Serialize the obj to dict."""
         raise NotImplementedError
diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py
index b40e466ed4..05e7cec658 100644
--- a/deepmd/pt/model/descriptor/dpa2.py
+++ b/deepmd/pt/model/descriptor/dpa2.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import logging
 from typing import (
     List,
     Optional,
 )
@@ -26,6 +27,8 @@
     DescrptBlockSeAtten,
 )
 
+log = logging.getLogger(__name__)
+
 
 @Descriptor.register("dpa2")
 class DescrptDPA2(Descriptor):
@@ -296,35 +299,76 @@ def compute_input_stats(self, merged):
                 }
                 for item in merged
             ]
-            (
-                sumr_tmp,
-                suma_tmp,
-                sumn_tmp,
-                sumr2_tmp,
-                suma2_tmp,
-            ) = descrpt.compute_input_stats(merged_tmp)
-            sumr.append(sumr_tmp)
-            suma.append(suma_tmp)
-            sumn.append(sumn_tmp)
-            sumr2.append(sumr2_tmp)
-            suma2.append(suma2_tmp)
-        return sumr, suma, sumn, sumr2, suma2
+            tmp_stat_dict = descrpt.compute_input_stats(merged_tmp)
+            sumr.append(tmp_stat_dict["sumr"])
+            suma.append(tmp_stat_dict["suma"])
+            sumn.append(tmp_stat_dict["sumn"])
+            sumr2.append(tmp_stat_dict["sumr2"])
+            suma2.append(tmp_stat_dict["suma2"])
+        return {
+            "sumr": sumr,
+            "suma": suma,
+            "sumn": sumn,
+            "sumr2": sumr2,
+            "suma2": suma2,
+        }
 
-    def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
+    def init_desc_stat(
+        self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
+    ):
+        assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
         for ii, descrpt in enumerate([self.repinit, self.repformers]):
-            descrpt.init_desc_stat(sumr[ii], suma[ii], sumn[ii], sumr2[ii], suma2[ii])
+            stat_dict_ii = {
+                "sumr": sumr[ii],
+                "suma": suma[ii],
+                "sumn": sumn[ii],
+                "sumr2": sumr2[ii],
+                "suma2": suma2[ii],
+            }
+            descrpt.init_desc_stat(**stat_dict_ii)
 
     @classmethod
-    def get_stat_name(cls, config):
-        descrpt_type = config["type"]
+    def get_stat_name(
+        cls,
+        ntypes,
+        type_name,
+        repinit_rcut=None,
+        repinit_rcut_smth=None,
+        repinit_nsel=None,
+        repformer_rcut=None,
+        repformer_rcut_smth=None,
+        repformer_nsel=None,
+        **kwargs,
+    ):
+        """
+        Get the name for the statistic file of the descriptor.
+        Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name.
+        """
+        descrpt_type = type_name
         assert descrpt_type in ["dpa2"]
+        assert all(
+            x is not None
+            for x in [
+                repinit_rcut,
+                repinit_rcut_smth,
+                repinit_nsel,
+                repformer_rcut,
+                repformer_rcut_smth,
+                repformer_nsel,
+            ]
+        )
         return (
-            f'stat_file_dpa2_repinit_rcut{config["repinit_rcut"]:.2f}_smth{config["repinit_rcut_smth"]:.2f}_sel{config["repinit_nsel"]}'
-            f'_repformer_rcut{config["repformer_rcut"]:.2f}_smth{config["repformer_rcut_smth"]:.2f}_sel{config["repformer_nsel"]}.npz'
+            f"stat_file_descrpt_dpa2_repinit_rcut{repinit_rcut:.2f}_smth{repinit_rcut_smth:.2f}_sel{repinit_nsel}"
+            f"_repformer_rcut{repformer_rcut:.2f}_smth{repformer_rcut_smth:.2f}_sel{repformer_nsel}_ntypes{ntypes}.npz"
        )
 
     @classmethod
     def get_data_process_key(cls, config):
+        """
+        Get the keys for the data preprocess.
+        Usually need the information of rcut and sel.
+        TODO Need to be deprecated when the dataloader has been cleaned up.
+        """
         descrpt_type = config["type"]
         assert descrpt_type in ["dpa2"]
         return {
@@ -332,6 +376,14 @@
             "rcut": [config["repinit_rcut"], config["repformer_rcut"]],
         }
 
+    @property
+    def data_stat_key(self):
+        """
+        Get the keys for the data statistic of the descriptor.
+        Return a list of statistic names needed, such as "sumr", "suma" or "sumn".
+        """
+        return ["sumr", "suma", "sumn", "sumr2", "suma2"]
+
     def serialize(self) -> dict:
         """Serialize the obj to dict."""
         raise NotImplementedError
diff --git a/deepmd/pt/model/descriptor/gaussian_lcc.py b/deepmd/pt/model/descriptor/gaussian_lcc.py
index 26ec1175b8..0972b90279 100644
--- a/deepmd/pt/model/descriptor/gaussian_lcc.py
+++ b/deepmd/pt/model/descriptor/gaussian_lcc.py
@@ -158,7 +158,7 @@ def compute_input_stats(self, merged):
         """Update mean and stddev for descriptor elements."""
         return [], [], [], [], []
 
-    def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
+    def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2, **kwargs):
         pass
 
     def forward(
diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py
index 0698992659..fb7e374ede 100644
--- a/deepmd/pt/model/descriptor/hybrid.py
+++ b/deepmd/pt/model/descriptor/hybrid.py
@@ -153,23 +153,33 @@ def compute_input_stats(self, merged):
             }
             for item in merged
         ]
-        (
-            sumr_tmp,
-            suma_tmp,
-            sumn_tmp,
-            sumr2_tmp,
-            suma2_tmp,
-        ) = descrpt.compute_input_stats(merged_tmp)
-        sumr.append(sumr_tmp)
-        suma.append(suma_tmp)
-        sumn.append(sumn_tmp)
-        sumr2.append(sumr2_tmp)
-        suma2.append(suma2_tmp)
-        return sumr, suma, sumn, sumr2, suma2
+        tmp_stat_dict = descrpt.compute_input_stats(merged_tmp)
+        sumr.append(tmp_stat_dict["sumr"])
+        suma.append(tmp_stat_dict["suma"])
+        sumn.append(tmp_stat_dict["sumn"])
+        sumr2.append(tmp_stat_dict["sumr2"])
+        suma2.append(tmp_stat_dict["suma2"])
+        return {
+            "sumr": sumr,
+            "suma": suma,
+            "sumn": sumn,
+            "sumr2": sumr2,
+            "suma2": suma2,
+        }
 
-    def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
+    def init_desc_stat(
+        self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
+    ):
+        assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
         for ii, descrpt in enumerate(self.descriptor_list):
-            descrpt.init_desc_stat(sumr[ii], suma[ii], sumn[ii], sumr2[ii], suma2[ii])
+            stat_dict_ii = {
+                "sumr": sumr[ii],
+                "suma": suma[ii],
+                "sumn": sumn[ii],
+                "sumr2": sumr2[ii],
+                "suma2": suma2[ii],
+            }
+            descrpt.init_desc_stat(**stat_dict_ii)
 
     def forward(
         self,
diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py
index 853962de69..0a302b6f92 100644
--- a/deepmd/pt/model/descriptor/repformers.py
+++ b/deepmd/pt/model/descriptor/repformers.py
@@ -321,9 +321,15 @@ def compute_input_stats(self, merged):
         sumn = np.sum(sumn, axis=0)
         sumr2 = np.sum(sumr2, axis=0)
         suma2 = np.sum(suma2, axis=0)
-        return sumr, suma, sumn, sumr2, suma2
+        return {
+            "sumr": sumr,
+            "suma": suma,
+            "sumn": sumn,
+            "sumr2": sumr2,
+            "suma2": suma2,
+        }
 
-    def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
+    def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2, **kwargs):
         all_davg = []
         all_dstd = []
         for type_i in range(self.ntypes):
diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py
index 23b78dcf34..82e7e5185a 100644
--- a/deepmd/pt/model/descriptor/se_a.py
+++ b/deepmd/pt/model/descriptor/se_a.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import logging
 from typing import (
     ClassVar,
     List,
@@ -37,6 +38,8 @@
     TypeFilter,
 )
 
+log = logging.getLogger(__name__)
+
 
 @Descriptor.register("se_e2_a")
 class DescrptSeA(Descriptor):
@@ -108,21 +111,44 @@ def compute_input_stats(self, merged):
         """Update mean and stddev for descriptor elements."""
         return self.sea.compute_input_stats(merged)
 
-    def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
+    def init_desc_stat(
+        self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
+    ):
+        assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
         self.sea.init_desc_stat(sumr, suma, sumn, sumr2, suma2)
 
     @classmethod
-    def get_stat_name(cls, config):
-        descrpt_type = config["type"]
+    def get_stat_name(
+        cls, ntypes, type_name, rcut=None, rcut_smth=None, sel=None, **kwargs
+    ):
+        """
+        Get the name for the statistic file of the descriptor.
+        Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name.
+        """
+        descrpt_type = type_name
         assert descrpt_type in ["se_e2_a"]
-        return f'stat_file_sea_rcut{config["rcut"]:.2f}_smth{config["rcut_smth"]:.2f}_sel{config["sel"]}.npz'
+        assert all(x is not None for x in [rcut, rcut_smth, sel])
+        return f"stat_file_descrpt_sea_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz"
 
     @classmethod
     def get_data_process_key(cls, config):
+        """
+        Get the keys for the data preprocess.
+        Usually need the information of rcut and sel.
+        TODO Need to be deprecated when the dataloader has been cleaned up.
+        """
         descrpt_type = config["type"]
         assert descrpt_type in ["se_e2_a"]
         return {"sel": config["sel"], "rcut": config["rcut"]}
 
+    @property
+    def data_stat_key(self):
+        """
+        Get the keys for the data statistic of the descriptor.
+        Return a list of statistic names needed, such as "sumr", "suma" or "sumn".
+ """ + return ["sumr", "suma", "sumn", "sumr2", "suma2"] + def forward( self, coord_ext: torch.Tensor, @@ -380,9 +406,15 @@ def compute_input_stats(self, merged): sumn = np.sum(sumn, axis=0) sumr2 = np.sum(sumr2, axis=0) suma2 = np.sum(suma2, axis=0) - return sumr, suma, sumn, sumr2, suma2 + return { + "sumr": sumr, + "suma": suma, + "sumn": sumn, + "sumr2": sumr2, + "suma2": suma2, + } - def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2, **kwargs): all_davg = [] all_dstd = [] for type_i in range(self.ntypes): diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 5d6e16fb96..3469d43e40 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -219,9 +219,15 @@ def compute_input_stats(self, merged): sumn = np.sum(sumn, axis=0) sumr2 = np.sum(sumr2, axis=0) suma2 = np.sum(suma2, axis=0) - return sumr, suma, sumn, sumr2, suma2 + return { + "sumr": sumr, + "suma": suma, + "sumn": sumn, + "sumr2": sumr2, + "suma2": suma2, + } - def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2, **kwargs): all_davg = [] all_dstd = [] for type_i in range(self.ntypes): diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 6cbab5af4d..1948acd003 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -23,7 +23,7 @@ ) -def get_zbl_model(model_params, sampled=None): +def get_zbl_model(model_params): model_params = copy.deepcopy(model_params) ntypes = len(model_params["type_map"]) # descriptor @@ -41,9 +41,7 @@ def get_zbl_model(model_params, sampled=None): if "ener" in fitting_net["type"]: fitting_net["return_energy"] = True fitting = Fitting(**fitting_net) - dp_model = DPAtomicModel( - descriptor, fitting, type_map=model_params["type_map"], resuming=True - ) + dp_model = DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"]) # pairtab filepath = model_params["use_srtab"] pt_model = PairTabModel( @@ -60,7 +58,7 @@ def get_zbl_model(model_params, sampled=None): ) -def get_model(model_params, sampled=None): +def get_model(model_params): model_params = copy.deepcopy(model_params) ntypes = len(model_params["type_map"]) # descriptor @@ -79,16 +77,7 @@ def get_model(model_params, sampled=None): fitting_net["return_energy"] = True fitting = Fitting(**fitting_net) - return EnergyModel( - descriptor, - fitting, - type_map=model_params["type_map"], - type_embedding=model_params.get("type_embedding", None), - resuming=model_params.get("resuming", False), - stat_file_dir=model_params.get("stat_file_dir", None), - stat_file_path=model_params.get("stat_file_path", None), - sampled=sampled, - ) + return EnergyModel(descriptor, fitting, type_map=model_params["type_map"]) __all__ = [ diff --git a/deepmd/pt/model/model/dp_atomic_model.py b/deepmd/pt/model/model/dp_atomic_model.py index d3eb07d2a6..89b814edaa 100644 --- a/deepmd/pt/model/model/dp_atomic_model.py +++ b/deepmd/pt/model/model/dp_atomic_model.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import copy +import logging +import os import sys from typing import ( Dict, List, Optional, + Union, ) import torch @@ -18,6 +21,9 @@ from deepmd.pt.model.task.ener import ( # noqa # TODO: should import all fittings! 
InvarFitting, ) +from deepmd.pt.utils.utils import ( + dict_to_device, +) from .base_atomic_model import ( BaseAtomicModel, @@ -26,6 +32,8 @@ BaseModel, ) +log = logging.getLogger(__name__) + class DPAtomicModel(BaseModel, BaseAtomicModel): """Model give atomic prediction of some physical property. @@ -39,31 +47,9 @@ class DPAtomicModel(BaseModel, BaseAtomicModel): type_map Mapping atom type to the name (str) of the type. For example `type_map[1]` gives the name of the type 1. - type_embedding - Type embedding net - resuming - Whether to resume/fine-tune from checkpoint or not. - stat_file_dir - The directory to the state files. - stat_file_path - The path to the state files. - sampled - Sampled frames to compute the statistics. """ - # I am enough with the shit interface! - def __init__( - self, - descriptor, - fitting, - type_map: Optional[List[str]], - type_embedding: Optional[dict] = None, - resuming: bool = False, - stat_file_dir=None, - stat_file_path=None, - sampled=None, - **kwargs, - ): + def __init__(self, descriptor, fitting, type_map: Optional[List[str]]): super().__init__() ntypes = len(type_map) self.type_map = type_map @@ -72,17 +58,6 @@ def __init__( self.rcut = self.descriptor.get_rcut() self.sel = self.descriptor.get_sel() self.fitting_net = fitting - # Statistics - fitting_net = None # TODO: hack!!! not sure if it is correct. - self.compute_or_load_stat( - fitting_net, - ntypes, - resuming=resuming, - type_map=type_map, - stat_file_dir=stat_file_dir, - stat_file_path=stat_file_path, - sampled=sampled, - ) def fitting_output_def(self) -> FittingOutputDef: """Get the output def of the fitting net.""" @@ -128,13 +103,7 @@ def deserialize(cls, data) -> "DPAtomicModel": fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize( data["fitting"] ) - # TODO: dirty hack to provide type_map and avoid data stat!!! - obj = cls( - descriptor_obj, - fitting_obj, - type_map=data["type_map"], - resuming=True, - ) + obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"]) return obj def forward_atomic( @@ -191,3 +160,45 @@ def forward_atomic( aparam=aparam, ) return fit_ret + + def compute_or_load_stat( + self, + type_map: Optional[List[str]] = None, + sampled=None, + stat_file_path_dict: Optional[Dict[str, Union[str, List[str]]]] = None, + ): + """ + Compute or load the statistics parameters of the model, + such as mean and standard deviation of descriptors or the energy bias of the fitting net. + When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update), + and saved in the `stat_file_path`(s). + When `sampled` is not provided, it will check the existence of `stat_file_path`(s) + and load the calculated statistics parameters. + + Parameters + ---------- + type_map + Mapping atom type to the name (str) of the type. + For example `type_map[1]` gives the name of the type 1. + sampled + The sampled data frames from different data systems. + stat_file_path_dict + The dictionary of paths to the statistics files. 
+ """ + if sampled is not None: # move data to device + for data_sys in sampled: + dict_to_device(data_sys) + if stat_file_path_dict is not None: + if not isinstance(stat_file_path_dict["descriptor"], list): + stat_file_dir = os.path.dirname(stat_file_path_dict["descriptor"]) + else: + stat_file_dir = os.path.dirname(stat_file_path_dict["descriptor"][0]) + if not os.path.exists(stat_file_dir): + os.mkdir(stat_file_dir) + self.descriptor.compute_or_load_stat( + type_map, sampled, stat_file_path_dict["descriptor"] + ) + if self.fitting_net is not None: + self.fitting_net.compute_or_load_stat( + type_map, sampled, stat_file_path_dict["fitting_net"] + ) diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py index 000746a213..51c5fcf123 100644 --- a/deepmd/pt/model/model/model.py +++ b/deepmd/pt/model/model/model.py @@ -1,19 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import logging -import os - -import numpy as np import torch -from deepmd.pt.utils import ( - env, -) -from deepmd.pt.utils.stat import ( - compute_output_stats, -) - -log = logging.getLogger(__name__) - class BaseModel(torch.nn.Module): def __init__(self): @@ -22,127 +9,26 @@ def __init__(self): def compute_or_load_stat( self, - fitting_param, - ntypes, - resuming=False, type_map=None, - stat_file_dir=None, - stat_file_path=None, sampled=None, + stat_file_path=None, ): - if fitting_param is None: - fitting_param = {} - if not resuming: - if sampled is not None: # compute stat - for sys in sampled: - for key in sys: - if isinstance(sys[key], list): - sys[key] = [item.to(env.DEVICE) for item in sys[key]] - else: - if sys[key] is not None: - sys[key] = sys[key].to(env.DEVICE) - sumr, suma, sumn, sumr2, suma2 = self.descriptor.compute_input_stats( - sampled - ) + """ + Compute or load the statistics parameters of the model, + such as mean and standard deviation of descriptors or the energy bias of the fitting net. + When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update), + and saved in the `stat_file_path`(s). + When `sampled` is not provided, it will check the existence of `stat_file_path`(s) + and load the calculated statistics parameters. 
- energy = [item["energy"] for item in sampled] - mixed_type = "real_natoms_vec" in sampled[0] - if mixed_type: - input_natoms = [item["real_natoms_vec"] for item in sampled] - else: - input_natoms = [item["natoms"] for item in sampled] - tmp = compute_output_stats(energy, input_natoms) - fitting_param["bias_atom_e"] = tmp[:, 0] - if stat_file_path is not None: - if not os.path.exists(stat_file_dir): - os.mkdir(stat_file_dir) - if not isinstance(stat_file_path, list): - log.info(f"Saving stat file to {stat_file_path}") - np.savez_compressed( - stat_file_path, - sumr=sumr, - suma=suma, - sumn=sumn, - sumr2=sumr2, - suma2=suma2, - bias_atom_e=fitting_param["bias_atom_e"], - type_map=type_map, - ) - else: - for ii, file_path in enumerate(stat_file_path): - log.info(f"Saving stat file to {file_path}") - np.savez_compressed( - file_path, - sumr=sumr[ii], - suma=suma[ii], - sumn=sumn[ii], - sumr2=sumr2[ii], - suma2=suma2[ii], - bias_atom_e=fitting_param["bias_atom_e"], - type_map=type_map, - ) - else: # load stat - target_type_map = type_map - if not isinstance(stat_file_path, list): - log.info(f"Loading stat file from {stat_file_path}") - stats = np.load(stat_file_path) - stat_type_map = list(stats["type_map"]) - missing_type = [ - i for i in target_type_map if i not in stat_type_map - ] - assert not missing_type, f"These type are not in stat file {stat_file_path}: {missing_type}! Please change the stat file path!" - idx_map = [stat_type_map.index(i) for i in target_type_map] - if stats["sumr"].size: - sumr, suma, sumn, sumr2, suma2 = ( - stats["sumr"][idx_map], - stats["suma"][idx_map], - stats["sumn"][idx_map], - stats["sumr2"][idx_map], - stats["suma2"][idx_map], - ) - else: - sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] - fitting_param["bias_atom_e"] = stats["bias_atom_e"][idx_map] - else: - sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] - id_bias_atom_e = None - for ii, file_path in enumerate(stat_file_path): - log.info(f"Loading stat file from {file_path}") - stats = np.load(file_path) - stat_type_map = list(stats["type_map"]) - missing_type = [ - i for i in target_type_map if i not in stat_type_map - ] - assert not missing_type, f"These type are not in stat file {file_path}: {missing_type}! Please change the stat file path!" - idx_map = [stat_type_map.index(i) for i in target_type_map] - if stats["sumr"].size: - sumr_tmp, suma_tmp, sumn_tmp, sumr2_tmp, suma2_tmp = ( - stats["sumr"][idx_map], - stats["suma"][idx_map], - stats["sumn"][idx_map], - stats["sumr2"][idx_map], - stats["suma2"][idx_map], - ) - else: - sumr_tmp, suma_tmp, sumn_tmp, sumr2_tmp, suma2_tmp = ( - [], - [], - [], - [], - [], - ) - sumr.append(sumr_tmp) - suma.append(suma_tmp) - sumn.append(sumn_tmp) - sumr2.append(sumr2_tmp) - suma2.append(suma2_tmp) - fitting_param["bias_atom_e"] = stats["bias_atom_e"][idx_map] - if id_bias_atom_e is None: - id_bias_atom_e = fitting_param["bias_atom_e"] - else: - assert ( - id_bias_atom_e == fitting_param["bias_atom_e"] - ).all(), "bias_atom_e in stat files are not consistent!" - self.descriptor.init_desc_stat(sumr, suma, sumn, sumr2, suma2) - else: # resuming for checkpoint; init model params from scratch - fitting_param["bias_atom_e"] = [0.0] * ntypes + Parameters + ---------- + type_map + Mapping atom type to the name (str) of the type. + For example `type_map[1]` gives the name of the type 1. + sampled + The sampled data frames from different data systems. + stat_file_path + The path to the statistics files. 
+        """
+        raise NotImplementedError
diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py
index d73c33545e..c8ade925c0 100644
--- a/deepmd/pt/model/task/ener.py
+++ b/deepmd/pt/model/task/ener.py
@@ -32,6 +32,9 @@
     DEFAULT_PRECISION,
     PRECISION_DICT,
 )
+from deepmd.pt.utils.stat import (
+    compute_output_bias,
+)
 from deepmd.pt.utils.utils import (
     to_numpy_array,
     to_torch_tensor,
 )
@@ -173,7 +176,7 @@ def output_def(self) -> FittingOutputDef:
 
     def __setitem__(self, key, value):
         if key in ["bias_atom_e"]:
-            # correct bias_atom_e shape. user may provide stupid shape
+            value = value.view([self.ntypes, self.dim_out])
             self.bias_atom_e = value
         elif key in ["fparam_avg"]:
             self.fparam_avg = value
@@ -200,6 +203,33 @@ def __getitem__(self, key):
         else:
             raise KeyError(key)
 
+    @property
+    def data_stat_key(self):
+        """
+        Get the keys for the data statistic of the fitting.
+        Return a list of statistic names needed, such as "bias_atom_e".
+        """
+        return ["bias_atom_e"]
+
+    def compute_output_stats(self, merged):
+        energy = [item["energy"] for item in merged]
+        mixed_type = "real_natoms_vec" in merged[0]
+        if mixed_type:
+            input_natoms = [item["real_natoms_vec"] for item in merged]
+        else:
+            input_natoms = [item["natoms"] for item in merged]
+        tmp = compute_output_bias(energy, input_natoms)
+        bias_atom_e = tmp[:, 0]
+        return {"bias_atom_e": bias_atom_e}
+
+    def init_fitting_stat(self, bias_atom_e=None, **kwargs):
+        assert bias_atom_e is not None
+        self.bias_atom_e.copy_(
+            torch.tensor(bias_atom_e, device=env.DEVICE).view(
+                [self.ntypes, self.dim_out]
+            )
+        )
+
     def serialize(self) -> dict:
         """Serialize the fitting to dict."""
         return {
@@ -394,6 +424,16 @@ def __init__(
             **kwargs,
         )
 
+    @classmethod
+    def get_stat_name(cls, ntypes, type_name="ener", **kwargs):
+        """
+        Get the name for the statistic file of the fitting.
+        Usually use the combination of fitting net name and ntypes as the statistic file name.
+        """
+        fitting_type = type_name
+        assert fitting_type in ["ener"]
+        return f"stat_file_fitting_ener_ntypes{ntypes}.npz"
+
 
 @Fitting.register("direct_force")
 @Fitting.register("direct_force_ener")
@@ -486,6 +526,16 @@ def serialize(self) -> dict:
     def deserialize(cls) -> "EnergyFittingNetDirect":
         raise NotImplementedError
 
+    @classmethod
+    def get_stat_name(cls, ntypes, type_name="ener", **kwargs):
+        """
+        Get the name for the statistic file of the fitting.
+        Usually use the combination of fitting net name and ntypes as the statistic file name.
+        """
+        fitting_type = type_name
+        assert fitting_type in ["direct_force", "direct_force_ener"]
+        return f"stat_file_fitting_direct_ntypes{ntypes}.npz"
+
     def forward(
         self,
         inputs: torch.Tensor,
diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py
index b03aee7539..360f545975 100644
--- a/deepmd/pt/model/task/fitting.py
+++ b/deepmd/pt/model/task/fitting.py
@@ -2,6 +2,9 @@
 import logging
 from typing import (
     Callable,
+    List,
+    Optional,
+    Union,
 )
 
 import numpy as np
@@ -94,6 +97,107 @@ def share_params(self, base_class, shared_level, resume=False):
         else:
             raise NotImplementedError
 
+    @classmethod
+    def get_stat_name(cls, ntypes, type_name="ener", **kwargs):
+        """
+        Get the name for the statistic file of the fitting.
+        Usually use the combination of fitting net name and ntypes as the statistic file name.
+        """
+        if cls is not Fitting:
+            raise NotImplementedError("get_stat_name is not implemented!")
+        fitting_type = type_name
+        return Fitting.__plugins.plugins[fitting_type].get_stat_name(
+            ntypes, type_name, **kwargs
+        )
+
+    @property
+    def data_stat_key(self):
+        """
+        Get the keys for the data statistic of the fitting.
+        Return a list of statistic names needed, such as "bias_atom_e".
+        """
+        raise NotImplementedError("data_stat_key is not implemented!")
+
+    def compute_or_load_stat(
+        self,
+        type_map: List[str],
+        sampled=None,
+        stat_file_path: Optional[Union[str, List[str]]] = None,
+    ):
+        """
+        Compute or load the statistics parameters of the fitting net.
+        Calculate and save the output bias to `stat_file_path`
+        if `sampled` is not None, otherwise load them from `stat_file_path`.
+
+        Parameters
+        ----------
+        type_map
+            Mapping atom type to the name (str) of the type.
+            For example `type_map[1]` gives the name of the type 1.
+        sampled
+            The sampled data frames from different data systems.
+        stat_file_path
+            The path to the statistics files.
+        """
+        fitting_stat_key = self.data_stat_key
+        if sampled is not None:
+            tmp_dict = self.compute_output_stats(sampled)
+            result_dict = {key: tmp_dict[key] for key in fitting_stat_key}
+            result_dict["type_map"] = type_map
+            self.save_stats(result_dict, stat_file_path)
+        else:  # load the statistics results
+            assert stat_file_path is not None, "No stat file to load!"
+            result_dict = self.load_stats(type_map, stat_file_path)
+        self.init_fitting_stat(**result_dict)
+
+    def save_stats(self, result_dict, stat_file_path: str):
+        """
+        Save the statistics results to `stat_file_path`.
+
+        Parameters
+        ----------
+        result_dict
+            The dictionary of statistics results.
+        stat_file_path
+            The path to the statistics file(s).
+        """
+        log.info(f"Saving stat file to {stat_file_path}")
+        np.savez_compressed(stat_file_path, **result_dict)
+
+    def load_stats(self, type_map, stat_file_path: str):
+        """
+        Load the statistics results from `stat_file_path`.
+
+        Parameters
+        ----------
+        type_map
+            Mapping atom type to the name (str) of the type.
+            For example `type_map[1]` gives the name of the type 1.
+        stat_file_path
+            The path to the statistics file(s).
+
+        Returns
+        -------
+        result_dict
+            The dictionary of statistics results.
+        """
+        fitting_stat_key = self.data_stat_key
+        target_type_map = type_map
+        log.info(f"Loading stat file from {stat_file_path}")
+        stats = np.load(stat_file_path)
+        stat_type_map = list(stats["type_map"])
+        missing_type = [i for i in target_type_map if i not in stat_type_map]
+        assert not missing_type, (
+            f"These types are not in the stat file {stat_file_path}: {missing_type}! "
+            f"Please change the stat file path!"
+ ) + idx_map = [stat_type_map.index(i) for i in target_type_map] + if stats[fitting_stat_key[0]].size: # not empty + result_dict = {key: stats[key][idx_map] for key in fitting_stat_key} + else: + result_dict = {key: [] for key in fitting_stat_key} + return result_dict + def change_energy_bias( self, config, model, old_type_map, new_type_map, bias_shift="delta", ntest=10 ): diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 02367f4aee..b2cac5a5eb 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -67,7 +67,8 @@ def __init__( self, config: Dict[str, Any], training_data, - sampled, + sampled=None, + stat_file_path=None, validation_data=None, init_model=None, restart_model=None, @@ -91,6 +92,8 @@ def __init__( self.model_keys = ( list(model_params["model_dict"]) if self.multi_task else ["Default"] ) + if self.multi_task and sampled is None: + sampled = {key: None for key in self.model_keys} self.rank = dist.get_rank() if dist.is_initialized() else 0 self.world_size = dist.get_world_size() if dist.is_initialized() else 1 self.num_model = len(self.model_keys) @@ -178,8 +181,14 @@ def get_data_loader(_training_data, _validation_data, _training_params): valid_numb_batch, ) - def get_single_model(_model_params, _sampled): - model = get_model(deepcopy(_model_params), _sampled).to(DEVICE) + def get_single_model(_model_params, _sampled, _stat_file_path): + model = get_model(deepcopy(_model_params)).to(DEVICE) + if not model_params.get("resuming", False): + model.compute_or_load_stat( + type_map=_model_params["type_map"], + sampled=_sampled, + stat_file_path_dict=_stat_file_path, + ) return model def get_lr(lr_params): @@ -229,7 +238,7 @@ def get_loss(loss_params, start_lr, _ntypes): self.validation_data, self.valid_numb_batch, ) = get_data_loader(training_data, validation_data, training_params) - self.model = get_single_model(model_params, sampled) + self.model = get_single_model(model_params, sampled, stat_file_path) else: ( self.training_dataloader, @@ -252,7 +261,9 @@ def get_loss(loss_params, start_lr, _ntypes): training_params["data_dict"][model_key], ) self.model[model_key] = get_single_model( - model_params["model_dict"][model_key], sampled[model_key] + model_params["model_dict"][model_key], + sampled[model_key], + stat_file_path[model_key], ) # Learning rate diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 5fde03c74a..932ba9a409 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +import os import numpy as np import torch @@ -86,8 +87,8 @@ def make_stat_input(datasets, dataloaders, nbatches): return lst -def compute_output_stats(energy, natoms, rcond=None): - """Update mean and stddev for descriptor elements. +def compute_output_bias(energy, natoms, rcond=None): + """Update output bias for fitting net. Args: - energy: Batched energy with shape [nframes, 1]. 
@@ -104,3 +105,32 @@ def compute_output_stats(energy, natoms, rcond=None): sys_tynatom = torch.cat(natoms)[:, 2:].cpu() energy_coef, _, _, _ = np.linalg.lstsq(sys_tynatom, sys_ener, rcond) return energy_coef + + +def process_stat_path( + stat_file_dict, stat_file_dir, model_params_dict, descriptor_cls, fitting_cls +): + if stat_file_dict is None: + stat_file_dict = {} + if "descriptor" in model_params_dict: + default_stat_file_name_descrpt = descriptor_cls.get_stat_name( + len(model_params_dict["type_map"]), + model_params_dict["descriptor"]["type"], + **model_params_dict["descriptor"], + ) + stat_file_dict["descriptor"] = default_stat_file_name_descrpt + if "fitting_net" in model_params_dict: + default_stat_file_name_fitting = fitting_cls.get_stat_name( + len(model_params_dict["type_map"]), + model_params_dict["fitting_net"].get("type", "ener"), + **model_params_dict["fitting_net"], + ) + stat_file_dict["fitting_net"] = default_stat_file_name_fitting + stat_file_path = { + key: os.path.join(stat_file_dir, stat_file_dict[key]) for key in stat_file_dict + } + + has_stat_file_path_list = [ + os.path.exists(stat_file_path[key]) for key in stat_file_dict + ] + return stat_file_path, False not in has_stat_file_path_list diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 2b96925a51..d6621f7b4c 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -81,3 +81,12 @@ def to_torch_tensor( if prec is None: raise ValueError(f"unknown precision {xx.dtype}") return torch.tensor(xx, dtype=prec, device=DEVICE) + + +def dict_to_device(sample_dict): + for key in sample_dict: + if isinstance(sample_dict[key], list): + sample_dict[key] = [item.to(DEVICE) for item in sample_dict[key]] + else: + if sample_dict[key] is not None: + sample_dict[key] = sample_dict[key].to(DEVICE) diff --git a/source/tests/pt/model/test_autodiff.py b/source/tests/pt/model/test_autodiff.py index 24dc69458d..e69e894af6 100644 --- a/source/tests/pt/model/test_autodiff.py +++ b/source/tests/pt/model/test_autodiff.py @@ -17,7 +17,6 @@ from .test_permutation import ( eval_model, - make_sample, model_dpa1, model_dpa2, model_se_e2_a, @@ -135,33 +134,29 @@ def ff(bb): class TestEnergyModelSeAForce(unittest.TestCase, ForceTest): def setUp(self): model_params = copy.deepcopy(model_se_e2_a) - sampled = make_sample(model_params) self.type_split = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelSeAVirial(unittest.TestCase, VirialTest): def setUp(self): model_params = copy.deepcopy(model_se_e2_a) - sampled = make_sample(model_params) self.type_split = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA1Force(unittest.TestCase, ForceTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA1Virial(unittest.TestCase, VirialTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA2Force(unittest.TestCase, ForceTest): @@ -173,10 +168,9 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = 
model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPAUniVirial(unittest.TestCase, VirialTest): @@ -188,23 +182,20 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelZBLForce(unittest.TestCase, ForceTest): def setUp(self): model_params = copy.deepcopy(model_zbl) - sampled = make_sample(model_params) self.type_split = False - self.model = get_zbl_model(model_params, sampled).to(env.DEVICE) + self.model = get_zbl_model(model_params).to(env.DEVICE) class TestEnergyModelZBLVirial(unittest.TestCase, VirialTest): def setUp(self): model_params = copy.deepcopy(model_zbl) - sampled = make_sample(model_params) self.type_split = False - self.model = get_zbl_model(model_params, sampled).to(env.DEVICE) + self.model = get_zbl_model(model_params).to(env.DEVICE) diff --git a/source/tests/pt/model/test_dp_atomic_model.py b/source/tests/pt/model/test_dp_atomic_model.py index 2960cb97cc..ef25e574d4 100644 --- a/source/tests/pt/model/test_dp_atomic_model.py +++ b/source/tests/pt/model/test_dp_atomic_model.py @@ -50,8 +50,7 @@ def test_self_consistency(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! - md0 = DPAtomicModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md0 = DPAtomicModel(ds, ft, type_map=type_map).to(env.DEVICE) md1 = DPAtomicModel.deserialize(md0.serialize()).to(env.DEVICE) args = [ to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] @@ -107,6 +106,5 @@ def test_jit(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! - md0 = DPAtomicModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md0 = DPAtomicModel(ds, ft, type_map=type_map).to(env.DEVICE) torch.jit.script(md0) diff --git a/source/tests/pt/model/test_dp_model.py b/source/tests/pt/model/test_dp_model.py index 51aa5d92f6..6e009d3934 100644 --- a/source/tests/pt/model/test_dp_model.py +++ b/source/tests/pt/model/test_dp_model.py @@ -56,8 +56,7 @@ def test_self_consistency(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! - md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md0 = DPModel(ds, ft, type_map=type_map).to(env.DEVICE) md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) args = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]] ret0 = md0.forward_common(*args) @@ -206,8 +205,7 @@ def test_self_consistency(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! 
- md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md0 = DPModel(ds, ft, type_map=type_map).to(env.DEVICE) md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) args = [ to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] @@ -285,8 +283,7 @@ def test_jit(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! - md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md0 = DPModel(ds, ft, type_map=type_map).to(env.DEVICE) torch.jit.script(md0) @@ -334,8 +331,7 @@ def setUp(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! - self.md = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + self.md = DPModel(ds, ft, type_map=type_map).to(env.DEVICE) def test_nlist_eq(self): # n_nnei == nnei @@ -408,8 +404,7 @@ def test_self_consistency(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! - md0 = EnergyModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE) md1 = EnergyModel.deserialize(md0.serialize()).to(env.DEVICE) args = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]] ret0 = md0.forward(*args) @@ -480,8 +475,7 @@ def test_self_consistency(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! - md0 = EnergyModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE) md1 = EnergyModel.deserialize(md0.serialize()).to(env.DEVICE) args = [ to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] @@ -526,6 +520,5 @@ def test_jit(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! 
-        md0 = EnergyModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE)
+        md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE)
         torch.jit.script(md0)
diff --git a/source/tests/pt/model/test_ener_fitting.py b/source/tests/pt/model/test_ener_fitting.py
index cbddf34dd6..42aeeff16a 100644
--- a/source/tests/pt/model/test_ener_fitting.py
+++ b/source/tests/pt/model/test_ener_fitting.py
@@ -178,4 +178,6 @@ def test_get_set(self):
             "aparam_inv_std",
         ]:
             ifn0[ii] = torch.tensor(foo, dtype=dtype, device=env.DEVICE)
-            np.testing.assert_allclose(foo, ifn0[ii].detach().cpu().numpy())
+            np.testing.assert_allclose(
+                foo, np.reshape(ifn0[ii].detach().cpu().numpy(), foo.shape)
+            )
diff --git a/source/tests/pt/model/test_force_grad.py b/source/tests/pt/model/test_force_grad.py
index 1ea4321d21..0a4dc32d9f 100644
--- a/source/tests/pt/model/test_force_grad.py
+++ b/source/tests/pt/model/test_force_grad.py
@@ -19,15 +19,9 @@
 from deepmd.pt.utils import (
     env,
 )
-from deepmd.pt.utils.dataloader import (
-    DpLoaderSet,
-)
 from deepmd.pt.utils.dataset import (
     DeepmdDataSystem,
 )
-from deepmd.pt.utils.stat import (
-    make_stat_input,
-)
 
 
 class CheckSymmetry(DeepmdDataSystem):
@@ -75,18 +69,7 @@ def setUp(self):
         self.get_model()
 
     def get_model(self):
-        training_systems = self.config["training"]["training_data"]["systems"]
-        model_params = self.config["model"]
-        data_stat_nbatch = model_params.get("data_stat_nbatch", 10)
-        train_data = DpLoaderSet(
-            training_systems,
-            self.config["training"]["training_data"]["batch_size"],
-            model_params,
-        )
-        sampled = make_stat_input(
-            train_data.systems, train_data.dataloaders, data_stat_nbatch
-        )
-        self.model = get_model(self.config["model"], sampled).to(env.DEVICE)
+        self.model = get_model(self.config["model"]).to(env.DEVICE)
 
     def get_dataset(self, system_index=0, batch_index=0):
         systems = self.config["training"]["training_data"]["systems"]
diff --git a/source/tests/pt/model/test_linear_atomic_model.py b/source/tests/pt/model/test_linear_atomic_model.py
index e9090de86a..e0247f911f 100644
--- a/source/tests/pt/model/test_linear_atomic_model.py
+++ b/source/tests/pt/model/test_linear_atomic_model.py
@@ -74,9 +74,7 @@ def test_pairwise(self, mock_loadtxt):
         type_map = ["foo", "bar"]
         zbl_model = PairTabModel(tab_file=file_path, rcut=0.3, sel=2)
-        dp_model = DPAtomicModel(ds, ft, type_map=type_map, resuming=True).to(
-            env.DEVICE
-        )
+        dp_model = DPAtomicModel(ds, ft, type_map=type_map).to(env.DEVICE)
         wgt_model = DPZBLLinearAtomicModel(
             dp_model,
             zbl_model,
@@ -142,9 +140,7 @@ def setUp(self, mock_loadtxt):
             distinguish_types=ds.distinguish_types(),
         ).to(env.DEVICE)
         type_map = ["foo", "bar"]
-        dp_model = DPAtomicModel(ds, ft, type_map=type_map, resuming=True).to(
-            env.DEVICE
-        )
+        dp_model = DPAtomicModel(ds, ft, type_map=type_map).to(env.DEVICE)
         zbl_model = PairTabModel(file_path, self.rcut, sum(self.sel))
         self.md0 = DPZBLLinearAtomicModel(
             dp_model,
diff --git a/source/tests/pt/model/test_model.py b/source/tests/pt/model/test_model.py
index bb99759d16..522b30b2df 100644
--- a/source/tests/pt/model/test_model.py
+++ b/source/tests/pt/model/test_model.py
@@ -30,9 +30,6 @@
     DEVICE,
 )
 from deepmd.pt.utils.learning_rate import LearningRateExp as MyLRExp
-from deepmd.pt.utils.stat import (
-    make_stat_input,
-)
 from deepmd.tf.common import (
     data_requirement,
     expand_sys_str,
@@ -282,9 +279,6 @@ def test_consistency(self):
                 "type_map": self.type_map,
             },
         )
-        sampled = make_stat_input(
-            my_ds.systems, my_ds.dataloaders, self.data_stat_nbatch
-        )
         my_model = get_model(
             model_params={
                 "descriptor": {
@@ -299,7 +293,6 @@ def test_consistency(self):
                 "data_stat_nbatch": self.data_stat_nbatch,
                 "type_map": self.type_map,
             },
-            sampled=sampled,
         )
         my_model.to(DEVICE)
         my_lr = MyLRExp(self.start_lr, self.stop_lr, self.decay_steps, self.stop_steps)
diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py
index 2301b6ea10..15359f873a 100644
--- a/source/tests/pt/model/test_permutation.py
+++ b/source/tests/pt/model/test_permutation.py
@@ -1,9 +1,6 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import copy
 import unittest
-from pathlib import (
-    Path,
-)
 
 import torch
 
@@ -17,12 +14,6 @@
 from deepmd.pt.utils import (
     env,
 )
-from deepmd.pt.utils.dataloader import (
-    DpLoaderSet,
-)
-from deepmd.pt.utils.stat import (
-    make_stat_input,
-)
 
 dtype = torch.float64
 
@@ -205,22 +196,6 @@
 }
 
 
-def make_sample(model_params):
-    training_systems = [
-        str(Path(__file__).parent / "water/data/data_0"),
-    ]
-    data_stat_nbatch = model_params.get("data_stat_nbatch", 10)
-    train_data = DpLoaderSet(
-        training_systems,
-        batch_size=4,
-        model_params=model_params.copy(),
-    )
-    sampled = make_stat_input(
-        train_data.systems, train_data.dataloaders, data_stat_nbatch
-    )
-    return sampled
-
-
 class PermutationTest:
     def test(
         self,
@@ -262,17 +237,15 @@ def test(
 class TestEnergyModelSeA(unittest.TestCase, PermutationTest):
     def setUp(self):
         model_params = copy.deepcopy(model_se_e2_a)
-        sampled = make_sample(model_params)
         self.type_split = False
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 class TestEnergyModelDPA1(unittest.TestCase, PermutationTest):
     def setUp(self):
         model_params = copy.deepcopy(model_dpa1)
-        sampled = make_sample(model_params)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 class TestEnergyModelDPA2(unittest.TestCase, PermutationTest):
@@ -284,10 +257,9 @@ def setUp(self):
         model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][
             "repinit_nsel"
         ]
-        sampled = make_sample(model_params_sample)
         model_params = copy.deepcopy(model_dpa2)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 class TestForceModelDPA2(unittest.TestCase, PermutationTest):
@@ -299,21 +271,19 @@ def setUp(self):
         model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][
             "repinit_nsel"
         ]
-        sampled = make_sample(model_params_sample)
         model_params = copy.deepcopy(model_dpa2)
         model_params["fitting_net"]["type"] = "direct_force_ener"
         self.type_split = True
         self.test_virial = False
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 @unittest.skip("hybrid not supported at the moment")
 class TestEnergyModelHybrid(unittest.TestCase, PermutationTest):
     def setUp(self):
         model_params = copy.deepcopy(model_hybrid)
-        sampled = make_sample(model_params)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 @unittest.skip("hybrid not supported at the moment")
@@ -321,25 +291,22 @@ class TestForceModelHybrid(unittest.TestCase, PermutationTest):
     def setUp(self):
         model_params = copy.deepcopy(model_hybrid)
         model_params["fitting_net"]["type"] = "direct_force_ener"
-        sampled = make_sample(model_params)
         self.type_split = True
         self.test_virial = False
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 class TestEnergyModelZBL(unittest.TestCase, PermutationTest):
     def setUp(self):
         model_params = copy.deepcopy(model_zbl)
-        sampled = make_sample(model_params)
         self.type_split = False
-        self.model = get_zbl_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_zbl_model(model_params).to(env.DEVICE)
 
 
 # class TestEnergyFoo(unittest.TestCase):
 #     def test(self):
 #         model_params = model_dpau
-#         sampled = make_sample(model_params)
-#         self.model = EnergyModelDPAUni(model_params, sampled).to(env.DEVICE)
+#         self.model = EnergyModelDPAUni(model_params).to(env.DEVICE)
 
 #         natoms = 5
 #         cell = torch.rand([3, 3], dtype=dtype)
diff --git a/source/tests/pt/model/test_permutation_denoise.py b/source/tests/pt/model/test_permutation_denoise.py
index 6dd61ab7e4..3b6be0c495 100644
--- a/source/tests/pt/model/test_permutation_denoise.py
+++ b/source/tests/pt/model/test_permutation_denoise.py
@@ -15,7 +15,6 @@
 )
 from .test_permutation import (  # model_dpau,
-    make_sample,
     model_dpa1,
     model_dpa2,
     model_hybrid,
@@ -70,9 +69,8 @@ def test(
 class TestDenoiseModelDPA1(unittest.TestCase, PermutationDenoiseTest):
     def setUp(self):
         model_params = copy.deepcopy(model_dpa1)
-        sampled = make_sample(model_params)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 @unittest.skip("support of the denoise is temporally disabled")
@@ -85,19 +83,19 @@ def setUp(self):
         model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][
             "repinit_nsel"
         ]
-        sampled = make_sample(model_params_sample)
         model_params = copy.deepcopy(model_dpa2)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(
+            model_params,
+        ).to(env.DEVICE)
 
 
 # @unittest.skip("hybrid not supported at the moment")
 # class TestDenoiseModelHybrid(unittest.TestCase, TestPermutationDenoise):
 #     def setUp(self):
 #         model_params = copy.deepcopy(model_hybrid_denoise)
-#         sampled = make_sample(model_params)
 #         self.type_split = True
-#         self.model = get_model(model_params, sampled).to(env.DEVICE)
+#         self.model = get_model(model_params).to(env.DEVICE)
 
 
 if __name__ == "__main__":
diff --git a/source/tests/pt/model/test_rot.py b/source/tests/pt/model/test_rot.py
index 982753e94f..780d193ebd 100644
--- a/source/tests/pt/model/test_rot.py
+++ b/source/tests/pt/model/test_rot.py
@@ -16,7 +16,6 @@
 )
 from .test_permutation import (  # model_dpau,
-    make_sample,
     model_dpa1,
     model_dpa2,
     model_hybrid,
@@ -114,17 +113,15 @@ def test(
 class TestEnergyModelSeA(unittest.TestCase, RotTest):
     def setUp(self):
         model_params = copy.deepcopy(model_se_e2_a)
-        sampled = make_sample(model_params)
         self.type_split = False
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 class TestEnergyModelDPA1(unittest.TestCase, RotTest):
     def setUp(self):
         model_params = copy.deepcopy(model_dpa1)
-        sampled = make_sample(model_params)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 class TestEnergyModelDPA2(unittest.TestCase, RotTest):
@@ -136,10 +133,9 @@ def setUp(self):
         model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][
             "repinit_nsel"
         ]
-        sampled = make_sample(model_params_sample)
         model_params = copy.deepcopy(model_dpa2)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 class TestForceModelDPA2(unittest.TestCase, RotTest):
@@ -151,21 +147,19 @@ def setUp(self):
         model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][
             "repinit_nsel"
         ]
-        sampled = make_sample(model_params_sample)
         model_params = copy.deepcopy(model_dpa2)
         model_params["fitting_net"]["type"] = "direct_force_ener"
         self.type_split = True
         self.test_virial = False
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 @unittest.skip("hybrid not supported at the moment")
 class TestEnergyModelHybrid(unittest.TestCase, RotTest):
     def setUp(self):
         model_params = copy.deepcopy(model_hybrid)
-        sampled = make_sample(model_params)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 @unittest.skip("hybrid not supported at the moment")
@@ -173,18 +167,16 @@ class TestForceModelHybrid(unittest.TestCase, RotTest):
     def setUp(self):
         model_params = copy.deepcopy(model_hybrid)
         model_params["fitting_net"]["type"] = "direct_force_ener"
-        sampled = make_sample(model_params)
         self.type_split = True
         self.test_virial = False
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 class TestEnergyModelZBL(unittest.TestCase, RotTest):
     def setUp(self):
         model_params = copy.deepcopy(model_zbl)
-        sampled = make_sample(model_params)
         self.type_split = False
-        self.model = get_zbl_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_zbl_model(model_params).to(env.DEVICE)
 
 
 if __name__ == "__main__":
diff --git a/source/tests/pt/model/test_rot_denoise.py b/source/tests/pt/model/test_rot_denoise.py
index 2cbfd8fd38..e4ae02f630 100644
--- a/source/tests/pt/model/test_rot_denoise.py
+++ b/source/tests/pt/model/test_rot_denoise.py
@@ -15,7 +15,6 @@
 )
 from .test_permutation_denoise import (
-    make_sample,
     model_dpa1,
     model_dpa2,
 )
@@ -101,9 +100,8 @@ def test(
 class TestDenoiseModelDPA1(unittest.TestCase, RotDenoiseTest):
     def setUp(self):
         model_params = copy.deepcopy(model_dpa1)
-        sampled = make_sample(model_params)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 @unittest.skip("support of the denoise is temporally disabled")
@@ -116,19 +114,17 @@ def setUp(self):
         model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][
             "repinit_nsel"
         ]
-        sampled = make_sample(model_params_sample)
         model_params = copy.deepcopy(model_dpa2)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 # @unittest.skip("hybrid not supported at the moment")
 # class TestEnergyModelHybrid(unittest.TestCase, TestRotDenoise):
 #     def setUp(self):
 #         model_params = copy.deepcopy(model_hybrid_denoise)
-#         sampled = make_sample(model_params)
 #         self.type_split = True
-#         self.model = get_model(model_params, sampled).to(env.DEVICE)
+#         self.model = get_model(model_params).to(env.DEVICE)
 
 
 if __name__ == "__main__":
diff --git a/source/tests/pt/model/test_rotation.py b/source/tests/pt/model/test_rotation.py
index a62e04eb89..5314959673 100644
--- a/source/tests/pt/model/test_rotation.py
+++ b/source/tests/pt/model/test_rotation.py
@@ -21,15 +21,9 @@
 from deepmd.pt.utils import (
     env,
 )
-from deepmd.pt.utils.dataloader import (
-    DpLoaderSet,
-)
 from deepmd.pt.utils.dataset import (
     DeepmdDataSystem,
 )
-from deepmd.pt.utils.stat import (
-    make_stat_input,
-)
 
 
 class CheckSymmetry(DeepmdDataSystem):
@@ -82,18 +76,7 @@ def setUp(self):
         self.get_model()
 
     def get_model(self):
-        training_systems = self.config["training"]["training_data"]["systems"]
-        model_params = self.config["model"]
-        data_stat_nbatch = model_params.get("data_stat_nbatch", 10)
-        train_data = DpLoaderSet(
-            training_systems,
-            self.config["training"]["training_data"]["batch_size"],
-            model_params,
-        )
-        sampled = make_stat_input(
-            train_data.systems, train_data.dataloaders, data_stat_nbatch
-        )
-        self.model = get_model(self.config["model"], sampled).to(env.DEVICE)
+        self.model = get_model(self.config["model"]).to(env.DEVICE)
 
     def get_dataset(self, system_index=0, batch_index=0):
         systems = self.config["training"]["training_data"]["systems"]
diff --git a/source/tests/pt/model/test_saveload_dpa1.py b/source/tests/pt/model/test_saveload_dpa1.py
index 1b4c41a204..64229b8e9e 100644
--- a/source/tests/pt/model/test_saveload_dpa1.py
+++ b/source/tests/pt/model/test_saveload_dpa1.py
@@ -109,14 +109,13 @@ def get_model_result(self, read=False, model_file="tmp_model.pt"):
 
     def create_wrapper(self, read: bool):
         model_config = copy.deepcopy(self.config["model"])
-        sampled = copy.deepcopy(self.sampled)
         model_config["resuming"] = read
         model_config["stat_file_dir"] = "stat_files"
         model_config["stat_file"] = "stat.npz"
         model_config["stat_file_path"] = os.path.join(
             model_config["stat_file_dir"], model_config["stat_file"]
         )
-        model = get_model(model_config, sampled).to(env.DEVICE)
+        model = get_model(model_config).to(env.DEVICE)
         return ModelWrapper(model, self.loss)
 
     def get_data(self):
diff --git a/source/tests/pt/model/test_saveload_se_e2_a.py b/source/tests/pt/model/test_saveload_se_e2_a.py
index 7f8364a16f..0632e30b5b 100644
--- a/source/tests/pt/model/test_saveload_se_e2_a.py
+++ b/source/tests/pt/model/test_saveload_se_e2_a.py
@@ -109,8 +109,7 @@ def get_model_result(self, read=False, model_file="tmp_model.pt"):
 
     def create_wrapper(self):
         model_config = copy.deepcopy(self.config["model"])
-        sampled = copy.deepcopy(self.sampled)
-        model = get_model(model_config, sampled).to(env.DEVICE)
+        model = get_model(model_config).to(env.DEVICE)
         return ModelWrapper(model, self.loss)
 
     def get_data(self):
diff --git a/source/tests/pt/model/test_smooth.py b/source/tests/pt/model/test_smooth.py
index f2f45c74aa..fa9042a932 100644
--- a/source/tests/pt/model/test_smooth.py
+++ b/source/tests/pt/model/test_smooth.py
@@ -16,7 +16,6 @@
 )
 from .test_permutation import (  # model_dpau,
-    make_sample,
     model_dpa1,
     model_dpa2,
     model_hybrid,
@@ -125,9 +124,8 @@ def compare(ret0, ret1):
 class TestEnergyModelSeA(unittest.TestCase, SmoothTest):
     def setUp(self):
         model_params = copy.deepcopy(model_se_e2_a)
-        sampled = make_sample(model_params)
         self.type_split = False
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
         self.epsilon, self.aprec = None, None
 
 
@@ -135,9 +133,8 @@ def setUp(self):
 class TestEnergyModelDPA1(unittest.TestCase, SmoothTest):
     def setUp(self):
         model_params = copy.deepcopy(model_dpa1)
-        sampled = make_sample(model_params)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
         # less degree of smoothness,
         # error can be systematically removed by reducing epsilon
         self.epsilon = 1e-5
@@ -160,9 +157,8 @@ def setUp(self):
         model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][
             "repinit_nsel"
         ]
-        sampled = make_sample(model_params_sample)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
         self.epsilon, self.aprec = 1e-5, 1e-4
 
 
@@ -177,10 +173,9 @@ def setUp(self):
         model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][
             "repinit_nsel"
         ]
-        sampled = make_sample(model_params_sample)
         self.type_split = True
         self.test_virial = False
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
         self.epsilon, self.aprec = None, None
 
 
@@ -195,10 +190,9 @@ def setUp(self):
         model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][
             "repinit_nsel"
         ]
-        sampled = make_sample(model_params_sample)
         self.type_split = True
         self.test_virial = False
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
         self.epsilon, self.aprec = None, None
 
 
@@ -206,26 +200,23 @@ def setUp(self):
 class TestEnergyModelHybrid(unittest.TestCase, SmoothTest):
     def setUp(self):
         model_params = copy.deepcopy(model_hybrid)
-        sampled = make_sample(model_params)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
         self.epsilon, self.aprec = None, None
 
 
 class TestEnergyModelZBL(unittest.TestCase, SmoothTest):
     def setUp(self):
         model_params = copy.deepcopy(model_zbl)
-        sampled = make_sample(model_params)
         self.type_split = False
-        self.model = get_zbl_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_zbl_model(model_params).to(env.DEVICE)
         self.epsilon, self.aprec = None, None
 
 
 # class TestEnergyFoo(unittest.TestCase):
 #     def test(self):
 #         model_params = model_dpau
-#         sampled = make_sample(model_params)
-#         self.model = EnergyModelDPAUni(model_params, sampled).to(env.DEVICE)
+#         self.model = EnergyModelDPAUni(model_params).to(env.DEVICE)
 
 #         natoms = 5
 #         cell = torch.rand([3, 3], dtype=dtype)
diff --git a/source/tests/pt/model/test_smooth_denoise.py b/source/tests/pt/model/test_smooth_denoise.py
index de89f8dccc..777d288f3c 100644
--- a/source/tests/pt/model/test_smooth_denoise.py
+++ b/source/tests/pt/model/test_smooth_denoise.py
@@ -15,7 +15,6 @@
 )
 from .test_permutation_denoise import (
-    make_sample,
     model_dpa2,
 )
@@ -106,12 +105,11 @@ def setUp(self):
         model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][
             "repinit_nsel"
         ]
-        sampled = make_sample(model_params_sample)
         model_params = copy.deepcopy(model_dpa2)
         model_params["descriptor"]["sel"] = 8
         model_params["descriptor"]["rcut_smth"] = 3.5
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
         self.epsilon, self.aprec = None, None
         self.epsilon = 1e-7
         self.aprec = 1e-5
@@ -127,11 +125,10 @@ def setUp(self):
         model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][
             "repinit_nsel"
         ]
-        sampled = make_sample(model_params_sample)
         model_params = copy.deepcopy(model_dpa2)
         # model_params["descriptor"]["combine_grrg"] = True
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
         self.epsilon, self.aprec = None, None
         self.epsilon = 1e-7
         self.aprec = 1e-5
@@ -141,9 +138,8 @@
 # class TestDenoiseModelHybrid(unittest.TestCase, TestSmoothDenoise):
 #     def setUp(self):
 #         model_params = copy.deepcopy(model_hybrid_denoise)
-#         sampled = make_sample(model_params)
 #         self.type_split = True
-#         self.model = get_model(model_params, sampled).to(env.DEVICE)
+#         self.model = get_model(model_params).to(env.DEVICE)
 #         self.epsilon, self.aprec = None, None
 #         self.epsilon = 1e-7
 #         self.aprec = 1e-5
diff --git a/source/tests/pt/model/test_trans.py b/source/tests/pt/model/test_trans.py
index 967d505c6d..a99d6c893f 100644
--- a/source/tests/pt/model/test_trans.py
+++ b/source/tests/pt/model/test_trans.py
@@ -16,7 +16,6 @@
 )
 from .test_permutation import (  # model_dpau,
-    make_sample,
     model_dpa1,
     model_dpa2,
     model_hybrid,
@@ -70,17 +69,15 @@ def test(
 class TestEnergyModelSeA(unittest.TestCase, TransTest):
     def setUp(self):
         model_params = copy.deepcopy(model_se_e2_a)
-        sampled = make_sample(model_params)
         self.type_split = False
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 class TestEnergyModelDPA1(unittest.TestCase, TransTest):
     def setUp(self):
         model_params = copy.deepcopy(model_dpa1)
-        sampled = make_sample(model_params)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 class TestEnergyModelDPA2(unittest.TestCase, TransTest):
@@ -92,10 +89,9 @@ def setUp(self):
         model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][
             "repinit_nsel"
         ]
-        sampled = make_sample(model_params_sample)
         model_params = copy.deepcopy(model_dpa2)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 class TestForceModelDPA2(unittest.TestCase, TransTest):
@@ -107,21 +103,19 @@ def setUp(self):
         model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][
             "repinit_nsel"
         ]
-        sampled = make_sample(model_params_sample)
         model_params = copy.deepcopy(model_dpa2)
         model_params["fitting_net"]["type"] = "direct_force_ener"
         self.type_split = True
         self.test_virial = False
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 @unittest.skip("hybrid not supported at the moment")
 class TestEnergyModelHybrid(unittest.TestCase, TransTest):
     def setUp(self):
         model_params = copy.deepcopy(model_hybrid)
-        sampled = make_sample(model_params)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 @unittest.skip("hybrid not supported at the moment")
@@ -129,18 +123,16 @@ class TestForceModelHybrid(unittest.TestCase, TransTest):
    def setUp(self):
         model_params = copy.deepcopy(model_hybrid)
         model_params["fitting_net"]["type"] = "direct_force_ener"
-        sampled = make_sample(model_params)
         self.type_split = True
         self.test_virial = False
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 class TestEnergyModelZBL(unittest.TestCase, TransTest):
     def setUp(self):
         model_params = copy.deepcopy(model_zbl)
-        sampled = make_sample(model_params)
         self.type_split = False
-        self.model = get_zbl_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_zbl_model(model_params).to(env.DEVICE)
 
 
 if __name__ == "__main__":
diff --git a/source/tests/pt/model/test_trans_denoise.py b/source/tests/pt/model/test_trans_denoise.py
index 88b926a3ae..9ba93a244a 100644
--- a/source/tests/pt/model/test_trans_denoise.py
+++ b/source/tests/pt/model/test_trans_denoise.py
@@ -15,7 +15,6 @@
 )
 from .test_permutation_denoise import (
-    make_sample,
     model_dpa1,
     model_dpa2,
     model_hybrid,
@@ -60,9 +59,8 @@ def test(
 class TestDenoiseModelDPA1(unittest.TestCase, TransDenoiseTest):
     def setUp(self):
         model_params = copy.deepcopy(model_dpa1)
-        sampled = make_sample(model_params)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 @unittest.skip("support of the denoise is temporally disabled")
@@ -75,19 +73,17 @@ def setUp(self):
         model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][
             "repinit_nsel"
         ]
-        sampled = make_sample(model_params_sample)
         model_params = copy.deepcopy(model_dpa2)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 @unittest.skip("hybrid not supported at the moment")
 class TestDenoiseModelHybrid(unittest.TestCase, TransDenoiseTest):
     def setUp(self):
         model_params = copy.deepcopy(model_hybrid)
-        sampled = make_sample(model_params)
         self.type_split = True
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
 
 
 if __name__ == "__main__":
diff --git a/source/tests/pt/model/test_unused_params.py b/source/tests/pt/model/test_unused_params.py
index a924979466..f69d8ac835 100644
--- a/source/tests/pt/model/test_unused_params.py
+++ b/source/tests/pt/model/test_unused_params.py
@@ -15,7 +15,6 @@
 )
 from .test_permutation import (
-    make_sample,
     model_dpa2,
 )
@@ -57,8 +56,7 @@ def test_unused(self):
         self._test_unused(model)
 
     def _test_unused(self, model_params):
-        sampled = make_sample(model_params)
-        self.model = get_model(model_params, sampled).to(env.DEVICE)
+        self.model = get_model(model_params).to(env.DEVICE)
         natoms = 5
         cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE)
         cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE)
diff --git a/source/tests/pt/test_stat.py b/source/tests/pt/test_stat.py
index 08fc12ff11..240c354a69 100644
--- a/source/tests/pt/test_stat.py
+++ b/source/tests/pt/test_stat.py
@@ -19,7 +19,7 @@
     DpLoaderSet,
 )
 from deepmd.pt.utils.stat import (
-    compute_output_stats,
+    compute_output_bias,
 )
 from deepmd.pt.utils.stat import make_stat_input as my_make
 from deepmd.tf.common import (
     data_requirement,
     expand_sys_str,
@@ -124,7 +124,7 @@ def my_merge(energy, natoms):
         energy, natoms = my_merge(energy, natoms)
         dp_fn = EnerFitting(self.dp_d, self.n_neuron)
         dp_fn.compute_output_stats(self.dp_sampled)
-        bias_atom_e = compute_output_stats(energy, natoms)
+        bias_atom_e = compute_output_bias(energy, natoms)
         self.assertTrue(np.allclose(dp_fn.bias_atom_e, bias_atom_e[:, 0]))
 
     # temporarily delete this function for performance of seeds in tf and pytorch may be different
@@ -172,8 +172,8 @@ def test_descriptor(self):
             ]:
                 if key in sys.keys():
                     sys[key] = sys[key].to(env.DEVICE)
-        sumr, suma, sumn, sumr2, suma2 = my_en.compute_input_stats(sampled)
-        my_en.init_desc_stat(sumr, suma, sumn, sumr2, suma2)
+        stat_dict = my_en.compute_input_stats(sampled)
+        my_en.init_desc_stat(**stat_dict)
         my_en.mean = my_en.mean
         my_en.stddev = my_en.stddev
        self.assertTrue(
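
The test changes above all follow one pattern: get_model and get_zbl_model are now called with the model parameters alone, and the precomputed "sampled" statistics argument (together with the make_sample helper) is gone. A minimal sketch of the new calling convention, assuming get_model is imported from deepmd.pt.model.model as in these test modules and reusing one of the parameter dicts they define; the import paths here are assumptions, not part of the patch:

    # Sketch of the post-change test setup: no make_sample()/sampled argument.
    import copy

    from deepmd.pt.model.model import get_model  # assumed import path
    from deepmd.pt.utils import env

    # `model_se_e2_a` is one of the parameter dicts defined in
    # source/tests/pt/model/test_permutation.py; in the tests it is
    # imported relatively, e.g. `from .test_permutation import model_se_e2_a`.
    from .test_permutation import model_se_e2_a

    model_params = copy.deepcopy(model_se_e2_a)
    # Statistics are no longer precomputed by the caller; the model is
    # built directly from its parameters and moved to the device.
    model = get_model(model_params).to(env.DEVICE)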
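
The final hunk in source/tests/pt/test_stat.py also exercises the reworked statistics interface: compute_input_stats now returns a dict of named statistics rather than a fixed (sumr, suma, sumn, sumr2, suma2) tuple, and init_desc_stat unpacks that dict as keyword arguments. A small sketch of this flow, where init_descriptor_stat is a hypothetical wrapper written for illustration only (it does not appear in the patch):

    # Hypothetical helper illustrating the dict-based statistics flow.
    # `descriptor` is any object providing compute_input_stats() and
    # init_desc_stat(); `sampled` is a list of training batches, such as
    # make_stat_input() in deepmd.pt.utils.stat produces.
    def init_descriptor_stat(descriptor, sampled):
        # compute_input_stats now returns a dict of named statistics
        # (previously a fixed tuple: sumr, suma, sumn, sumr2, suma2).
        stat_dict = descriptor.compute_input_stats(sampled)
        # init_desc_stat consumes the statistics as keyword arguments, so
        # descriptors with different statistics can share one interface.
        descriptor.init_desc_stat(**stat_dict)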