diff --git a/deepmd/dpmodel/__init__.py b/deepmd/dpmodel/__init__.py index 5a83bb7bd4..c005223312 100644 --- a/deepmd/dpmodel/__init__.py +++ b/deepmd/dpmodel/__init__.py @@ -4,6 +4,9 @@ PRECISION_DICT, NativeOP, ) +from .descriptor import ( + DescrptSeA, +) from .model import ( DPAtomicModel, DPModel, @@ -17,6 +20,13 @@ get_reduce_name, model_check_output, ) +from .utils import ( + EmbeddingNet, + EnvMat, + FittingNet, + NativeLayer, + NativeNet, +) __all__ = [ "DPModel", @@ -24,6 +34,12 @@ "PRECISION_DICT", "DEFAULT_PRECISION", "NativeOP", + "EnvMat", + "NativeLayer", + "NativeNet", + "EmbeddingNet", + "FittingNet", + "DescrptSeA", "ModelOutputDef", "FittingOutputDef", "OutputVariableDef", diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 4ba9e17b52..ad553a57ba 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -48,7 +48,7 @@ def __init__( assert not self.multi_task, "multitask mode currently not supported!" self.type_split = self.input_param["descriptor"]["type"] in ["se_e2_a"] self.type_map = self.input_param["type_map"] - self.dp = ModelWrapper(get_model(self.input_param, None).to(DEVICE)) + self.dp = ModelWrapper(get_model(self.input_param).to(DEVICE)) self.dp.load_state_dict(state_dict) self.rcut = self.dp.model["Default"].descriptor.get_rcut() self.sec = np.cumsum(self.dp.model["Default"].descriptor.get_sel()) diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index b4e866bb11..f1c9ca4829 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -57,11 +57,19 @@ class SomeDescript(Descriptor): @classmethod def get_stat_name(cls, config): + """Get the name for the statistic file of the descriptor.""" descrpt_type = config["type"] return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config) + @classmethod + def get_data_stat_key(cls, config): + """Get the keys for the data statistic of the descriptor.""" + descrpt_type = config["type"] + return Descriptor.__plugins.plugins[descrpt_type].get_data_stat_key(config) + @classmethod def get_data_process_key(cls, config): + """Get the keys for the data preprocess.""" descrpt_type = config["type"] return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 914c37ed51..1ac872e38e 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -127,12 +127,19 @@ def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): @classmethod def get_stat_name(cls, config): + """Get the name for the statistic file of the descriptor.""" descrpt_type = config["type"] assert descrpt_type in ["dpa1", "se_atten"] return f'stat_file_dpa1_rcut{config["rcut"]:.2f}_smth{config["rcut_smth"]:.2f}_sel{config["sel"]}.npz' + @classmethod + def get_data_stat_key(cls, config): + """Get the keys for the data statistic of the descriptor.""" + return ["sumr", "suma", "sumn", "sumr2", "suma2"] + @classmethod def get_data_process_key(cls, config): + """Get the keys for the data preprocess.""" descrpt_type = config["type"] assert descrpt_type in ["dpa1", "se_atten"] return {"sel": config["sel"], "rcut": config["rcut"]} diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index b40e466ed4..d221493e28 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -316,6 +316,7 @@ def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): @classmethod def get_stat_name(cls, config): + """Get the name for the statistic file of the descriptor.""" descrpt_type = config["type"] assert descrpt_type in ["dpa2"] return ( @@ -323,8 +324,14 @@ def get_stat_name(cls, config): f'_repformer_rcut{config["repformer_rcut"]:.2f}_smth{config["repformer_rcut_smth"]:.2f}_sel{config["repformer_nsel"]}.npz' ) + @classmethod + def get_data_stat_key(cls, config): + """Get the keys for the data statistic of the descriptor.""" + return ["sumr", "suma", "sumn", "sumr2", "suma2"] + @classmethod def get_data_process_key(cls, config): + """Get the keys for the data preprocess.""" descrpt_type = config["type"] assert descrpt_type in ["dpa2"] return { diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 23b78dcf34..d841b91ccb 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -113,12 +113,19 @@ def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): @classmethod def get_stat_name(cls, config): + """Get the name for the statistic file of the descriptor.""" descrpt_type = config["type"] assert descrpt_type in ["se_e2_a"] return f'stat_file_sea_rcut{config["rcut"]:.2f}_smth{config["rcut_smth"]:.2f}_sel{config["sel"]}.npz' + @classmethod + def get_data_stat_key(cls, config): + """Get the keys for the data statistic of the descriptor.""" + return ["sumr", "suma", "sumn", "sumr2", "suma2"] + @classmethod def get_data_process_key(cls, config): + """Get the keys for the data preprocess.""" descrpt_type = config["type"] assert descrpt_type in ["se_e2_a"] return {"sel": config["sel"], "rcut": config["rcut"]} diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index c4de02ed20..36d46e32fc 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -16,7 +16,7 @@ ) -def get_model(model_params, sampled=None): +def get_model(model_params): model_params = copy.deepcopy(model_params) ntypes = len(model_params["type_map"]) # descriptor @@ -35,16 +35,7 @@ def get_model(model_params, sampled=None): fitting_net["return_energy"] = True fitting = Fitting(**fitting_net) - return EnergyModel( - descriptor, - fitting, - type_map=model_params["type_map"], - type_embedding=model_params.get("type_embedding", None), - resuming=model_params.get("resuming", False), - stat_file_dir=model_params.get("stat_file_dir", None), - stat_file_path=model_params.get("stat_file_path", None), - sampled=sampled, - ) + return EnergyModel(descriptor, fitting, type_map=model_params["type_map"]) __all__ = [ diff --git a/deepmd/pt/model/model/dp_atomic_model.py b/deepmd/pt/model/model/dp_atomic_model.py index b2ae48628b..3e2418ed4a 100644 --- a/deepmd/pt/model/model/dp_atomic_model.py +++ b/deepmd/pt/model/model/dp_atomic_model.py @@ -39,31 +39,9 @@ class DPAtomicModel(BaseModel, BaseAtomicModel): type_map Mapping atom type to the name (str) of the type. For example `type_map[1]` gives the name of the type 1. - type_embedding - Type embedding net - resuming - Whether to resume/fine-tune from checkpoint or not. - stat_file_dir - The directory to the state files. - stat_file_path - The path to the state files. - sampled - Sampled frames to compute the statistics. """ - # I am enough with the shit interface! - def __init__( - self, - descriptor, - fitting, - type_map: Optional[List[str]], - type_embedding: Optional[dict] = None, - resuming: bool = False, - stat_file_dir=None, - stat_file_path=None, - sampled=None, - **kwargs, - ): + def __init__(self, descriptor, fitting, type_map: Optional[List[str]]): super().__init__() ntypes = len(type_map) self.type_map = type_map @@ -72,17 +50,6 @@ def __init__( self.rcut = self.descriptor.get_rcut() self.sel = self.descriptor.get_sel() self.fitting_net = fitting - # Statistics - fitting_net = None # TODO: hack!!! not sure if it is correct. - self.compute_or_load_stat( - fitting_net, - ntypes, - resuming=resuming, - type_map=type_map, - stat_file_dir=stat_file_dir, - stat_file_path=stat_file_path, - sampled=sampled, - ) def fitting_output_def(self) -> FittingOutputDef: """Get the output def of the fitting net.""" @@ -122,13 +89,7 @@ def deserialize(cls, data) -> "DPAtomicModel": fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize( data["fitting"] ) - # TODO: dirty hack to provide type_map and avoid data stat!!! - obj = cls( - descriptor_obj, - fitting_obj, - type_map=data["type_map"], - resuming=True, - ) + obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"]) return obj def forward_atomic( diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py index 000746a213..3e3bfe0a86 100644 --- a/deepmd/pt/model/model/model.py +++ b/deepmd/pt/model/model/model.py @@ -5,12 +5,12 @@ import numpy as np import torch -from deepmd.pt.utils import ( - env, -) from deepmd.pt.utils.stat import ( compute_output_stats, ) +from deepmd.pt.utils.utils import ( + dict_to_device, +) log = logging.getLogger(__name__) @@ -22,78 +22,89 @@ def __init__(self): def compute_or_load_stat( self, - fitting_param, - ntypes, - resuming=False, type_map=None, stat_file_dir=None, stat_file_path=None, sampled=None, ): - if fitting_param is None: - fitting_param = {} - if not resuming: - if sampled is not None: # compute stat - for sys in sampled: - for key in sys: - if isinstance(sys[key], list): - sys[key] = [item.to(env.DEVICE) for item in sys[key]] - else: - if sys[key] is not None: - sys[key] = sys[key].to(env.DEVICE) - sumr, suma, sumn, sumr2, suma2 = self.descriptor.compute_input_stats( - sampled - ) - - energy = [item["energy"] for item in sampled] - mixed_type = "real_natoms_vec" in sampled[0] - if mixed_type: - input_natoms = [item["real_natoms_vec"] for item in sampled] + bias_atom_e = None + if sampled is not None: # compute stat + for sys in sampled: + dict_to_device(sys) + sumr, suma, sumn, sumr2, suma2 = self.descriptor.compute_input_stats( + sampled + ) + energy = [item["energy"] for item in sampled] + mixed_type = "real_natoms_vec" in sampled[0] + if mixed_type: + input_natoms = [item["real_natoms_vec"] for item in sampled] + else: + input_natoms = [item["natoms"] for item in sampled] + tmp = compute_output_stats(energy, input_natoms) + bias_atom_e = tmp[:, 0] + if stat_file_path is not None: + if not os.path.exists(stat_file_dir): + os.mkdir(stat_file_dir) + if not isinstance(stat_file_path, list): + log.info(f"Saving stat file to {stat_file_path}") + np.savez_compressed( + stat_file_path, + sumr=sumr, + suma=suma, + sumn=sumn, + sumr2=sumr2, + suma2=suma2, + bias_atom_e=bias_atom_e, + type_map=type_map, + ) else: - input_natoms = [item["natoms"] for item in sampled] - tmp = compute_output_stats(energy, input_natoms) - fitting_param["bias_atom_e"] = tmp[:, 0] - if stat_file_path is not None: - if not os.path.exists(stat_file_dir): - os.mkdir(stat_file_dir) - if not isinstance(stat_file_path, list): - log.info(f"Saving stat file to {stat_file_path}") + for ii, file_path in enumerate(stat_file_path): + log.info(f"Saving stat file to {file_path}") np.savez_compressed( - stat_file_path, - sumr=sumr, - suma=suma, - sumn=sumn, - sumr2=sumr2, - suma2=suma2, - bias_atom_e=fitting_param["bias_atom_e"], + file_path, + sumr=sumr[ii], + suma=suma[ii], + sumn=sumn[ii], + sumr2=sumr2[ii], + suma2=suma2[ii], + bias_atom_e=bias_atom_e, type_map=type_map, ) - else: - for ii, file_path in enumerate(stat_file_path): - log.info(f"Saving stat file to {file_path}") - np.savez_compressed( - file_path, - sumr=sumr[ii], - suma=suma[ii], - sumn=sumn[ii], - sumr2=sumr2[ii], - suma2=suma2[ii], - bias_atom_e=fitting_param["bias_atom_e"], - type_map=type_map, - ) - else: # load stat - target_type_map = type_map - if not isinstance(stat_file_path, list): - log.info(f"Loading stat file from {stat_file_path}") - stats = np.load(stat_file_path) + else: # load stat + target_type_map = type_map + if not isinstance(stat_file_path, list): + log.info(f"Loading stat file from {stat_file_path}") + stats = np.load(stat_file_path) + stat_type_map = list(stats["type_map"]) + missing_type = [i for i in target_type_map if i not in stat_type_map] + assert not missing_type, f"These type are not in stat file {stat_file_path}: {missing_type}! Please change the stat file path!" + idx_map = [stat_type_map.index(i) for i in target_type_map] + if stats["sumr"].size: + sumr, suma, sumn, sumr2, suma2 = ( + stats["sumr"][idx_map], + stats["suma"][idx_map], + stats["sumn"][idx_map], + stats["sumr2"][idx_map], + stats["suma2"][idx_map], + ) + else: + sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] + + bias_atom_e = stats["bias_atom_e"][idx_map] + else: + sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] + id_bias_atom_e = None + for ii, file_path in enumerate(stat_file_path): + log.info(f"Loading stat file from {file_path}") + stats = np.load(file_path) stat_type_map = list(stats["type_map"]) missing_type = [ i for i in target_type_map if i not in stat_type_map ] - assert not missing_type, f"These type are not in stat file {stat_file_path}: {missing_type}! Please change the stat file path!" + assert not missing_type, f"These type are not in stat file {file_path}: {missing_type}! Please change the stat file path!" idx_map = [stat_type_map.index(i) for i in target_type_map] if stats["sumr"].size: - sumr, suma, sumn, sumr2, suma2 = ( + sumr_tmp, suma_tmp, sumn_tmp, sumr2_tmp, suma2_tmp = ( stats["sumr"][idx_map], stats["suma"][idx_map], stats["sumn"][idx_map], @@ -101,48 +112,27 @@ def compute_or_load_stat( stats["suma2"][idx_map], ) else: - sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] - fitting_param["bias_atom_e"] = stats["bias_atom_e"][idx_map] - else: - sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] - id_bias_atom_e = None - for ii, file_path in enumerate(stat_file_path): - log.info(f"Loading stat file from {file_path}") - stats = np.load(file_path) - stat_type_map = list(stats["type_map"]) - missing_type = [ - i for i in target_type_map if i not in stat_type_map - ] - assert not missing_type, f"These type are not in stat file {file_path}: {missing_type}! Please change the stat file path!" - idx_map = [stat_type_map.index(i) for i in target_type_map] - if stats["sumr"].size: - sumr_tmp, suma_tmp, sumn_tmp, sumr2_tmp, suma2_tmp = ( - stats["sumr"][idx_map], - stats["suma"][idx_map], - stats["sumn"][idx_map], - stats["sumr2"][idx_map], - stats["suma2"][idx_map], - ) - else: - sumr_tmp, suma_tmp, sumn_tmp, sumr2_tmp, suma2_tmp = ( - [], - [], - [], - [], - [], - ) - sumr.append(sumr_tmp) - suma.append(suma_tmp) - sumn.append(sumn_tmp) - sumr2.append(sumr2_tmp) - suma2.append(suma2_tmp) - fitting_param["bias_atom_e"] = stats["bias_atom_e"][idx_map] - if id_bias_atom_e is None: - id_bias_atom_e = fitting_param["bias_atom_e"] - else: - assert ( - id_bias_atom_e == fitting_param["bias_atom_e"] - ).all(), "bias_atom_e in stat files are not consistent!" - self.descriptor.init_desc_stat(sumr, suma, sumn, sumr2, suma2) - else: # resuming for checkpoint; init model params from scratch - fitting_param["bias_atom_e"] = [0.0] * ntypes + sumr_tmp, suma_tmp, sumn_tmp, sumr2_tmp, suma2_tmp = ( + [], + [], + [], + [], + [], + ) + sumr.append(sumr_tmp) + suma.append(suma_tmp) + sumn.append(sumn_tmp) + sumr2.append(sumr2_tmp) + suma2.append(suma2_tmp) + + bias_atom_e = stats["bias_atom_e"][idx_map] + if id_bias_atom_e is None: + id_bias_atom_e = bias_atom_e + else: + assert ( + id_bias_atom_e == bias_atom_e + ).all(), "bias_atom_e in stat files are not consistent!" + + self.descriptor.init_desc_stat(sumr, suma, sumn, sumr2, suma2) + if self.fitting_net is not None: + self.fitting_net.init_energy_bias(bias_atom_e) diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index 5e3cd87367..1308530341 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -169,7 +169,7 @@ def output_def(self) -> FittingOutputDef: def __setitem__(self, key, value): if key in ["bias_atom_e"]: - # correct bias_atom_e shape. user may provide stupid shape + value = value.view([self.ntypes, self.dim_out]) self.bias_atom_e = value elif key in ["fparam_avg"]: self.fparam_avg = value diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index b03aee7539..f9d9167c71 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -94,6 +94,11 @@ def share_params(self, base_class, shared_level, resume=False): else: raise NotImplementedError + def init_energy_bias(self, bias_atom_e): + self.bias_atom_e.copy_( + torch.tensor(bias_atom_e, device=DEVICE).view([self.ntypes, self.dim_out]) + ) + def change_energy_bias( self, config, model, old_type_map, new_type_map, bias_shift="delta", ntest=10 ): diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 02367f4aee..bf829c0762 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -179,7 +179,14 @@ def get_data_loader(_training_data, _validation_data, _training_params): ) def get_single_model(_model_params, _sampled): - model = get_model(deepcopy(_model_params), _sampled).to(DEVICE) + model = get_model(deepcopy(_model_params)).to(DEVICE) + if not model_params.get("resuming", False): + model.compute_or_load_stat( + type_map=_model_params["type_map"], + stat_file_dir=model_params.get("stat_file_dir", None), + stat_file_path=model_params.get("stat_file_path", None), + sampled=_sampled, + ) return model def get_lr(lr_params): diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 2b96925a51..d6621f7b4c 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -81,3 +81,12 @@ def to_torch_tensor( if prec is None: raise ValueError(f"unknown precision {xx.dtype}") return torch.tensor(xx, dtype=prec, device=DEVICE) + + +def dict_to_device(sample_dict): + for key in sample_dict: + if isinstance(sample_dict[key], list): + sample_dict[key] = [item.to(DEVICE) for item in sample_dict[key]] + else: + if sample_dict[key] is not None: + sample_dict[key] = sample_dict[key].to(DEVICE) diff --git a/source/tests/pt/test_autodiff.py b/source/tests/pt/test_autodiff.py index 8840fbdd4c..05b44955ea 100644 --- a/source/tests/pt/test_autodiff.py +++ b/source/tests/pt/test_autodiff.py @@ -16,7 +16,6 @@ from .test_permutation import ( eval_model, - make_sample, model_dpa1, model_dpa2, model_se_e2_a, @@ -133,33 +132,29 @@ def ff(bb): class TestEnergyModelSeAForce(unittest.TestCase, ForceTest): def setUp(self): model_params = copy.deepcopy(model_se_e2_a) - sampled = make_sample(model_params) self.type_split = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelSeAVirial(unittest.TestCase, VirialTest): def setUp(self): model_params = copy.deepcopy(model_se_e2_a) - sampled = make_sample(model_params) self.type_split = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA1Force(unittest.TestCase, ForceTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA1Virial(unittest.TestCase, VirialTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA2Force(unittest.TestCase, ForceTest): @@ -171,10 +166,9 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPAUniVirial(unittest.TestCase, VirialTest): @@ -186,7 +180,6 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) diff --git a/source/tests/pt/test_dp_atomic_model.py b/source/tests/pt/test_dp_atomic_model.py index 2960cb97cc..ef25e574d4 100644 --- a/source/tests/pt/test_dp_atomic_model.py +++ b/source/tests/pt/test_dp_atomic_model.py @@ -50,8 +50,7 @@ def test_self_consistency(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! - md0 = DPAtomicModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md0 = DPAtomicModel(ds, ft, type_map=type_map).to(env.DEVICE) md1 = DPAtomicModel.deserialize(md0.serialize()).to(env.DEVICE) args = [ to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] @@ -107,6 +106,5 @@ def test_jit(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! - md0 = DPAtomicModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md0 = DPAtomicModel(ds, ft, type_map=type_map).to(env.DEVICE) torch.jit.script(md0) diff --git a/source/tests/pt/test_dp_model.py b/source/tests/pt/test_dp_model.py index 79f65d26d6..b96150a9f2 100644 --- a/source/tests/pt/test_dp_model.py +++ b/source/tests/pt/test_dp_model.py @@ -55,8 +55,7 @@ def test_self_consistency(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! - md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md0 = DPModel(ds, ft, type_map=type_map).to(env.DEVICE) md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) args = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]] ret0 = md0.forward_common(*args) @@ -205,8 +204,7 @@ def test_self_consistency(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! - md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md0 = DPModel(ds, ft, type_map=type_map).to(env.DEVICE) md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) args = [ to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] @@ -284,8 +282,7 @@ def test_jit(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! - md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + md0 = DPModel(ds, ft, type_map=type_map).to(env.DEVICE) torch.jit.script(md0) @@ -333,8 +330,7 @@ def setUp(self): distinguish_types=ds.distinguish_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - # TODO: dirty hack to avoid data stat!!! - self.md = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + self.md = DPModel(ds, ft, type_map=type_map).to(env.DEVICE) def test_nlist_eq(self): # n_nnei == nnei diff --git a/source/tests/pt/test_ener_fitting.py b/source/tests/pt/test_ener_fitting.py index cbddf34dd6..42aeeff16a 100644 --- a/source/tests/pt/test_ener_fitting.py +++ b/source/tests/pt/test_ener_fitting.py @@ -178,4 +178,6 @@ def test_get_set(self): "aparam_inv_std", ]: ifn0[ii] = torch.tensor(foo, dtype=dtype, device=env.DEVICE) - np.testing.assert_allclose(foo, ifn0[ii].detach().cpu().numpy()) + np.testing.assert_allclose( + foo, np.reshape(ifn0[ii].detach().cpu().numpy(), foo.shape) + ) diff --git a/source/tests/pt/test_force_grad.py b/source/tests/pt/test_force_grad.py index 1ea4321d21..0a4dc32d9f 100644 --- a/source/tests/pt/test_force_grad.py +++ b/source/tests/pt/test_force_grad.py @@ -19,15 +19,9 @@ from deepmd.pt.utils import ( env, ) -from deepmd.pt.utils.dataloader import ( - DpLoaderSet, -) from deepmd.pt.utils.dataset import ( DeepmdDataSystem, ) -from deepmd.pt.utils.stat import ( - make_stat_input, -) class CheckSymmetry(DeepmdDataSystem): @@ -75,18 +69,7 @@ def setUp(self): self.get_model() def get_model(self): - training_systems = self.config["training"]["training_data"]["systems"] - model_params = self.config["model"] - data_stat_nbatch = model_params.get("data_stat_nbatch", 10) - train_data = DpLoaderSet( - training_systems, - self.config["training"]["training_data"]["batch_size"], - model_params, - ) - sampled = make_stat_input( - train_data.systems, train_data.dataloaders, data_stat_nbatch - ) - self.model = get_model(self.config["model"], sampled).to(env.DEVICE) + self.model = get_model(self.config["model"]).to(env.DEVICE) def get_dataset(self, system_index=0, batch_index=0): systems = self.config["training"]["training_data"]["systems"] diff --git a/source/tests/pt/test_model.py b/source/tests/pt/test_model.py index bb99759d16..609ffe73e0 100644 --- a/source/tests/pt/test_model.py +++ b/source/tests/pt/test_model.py @@ -30,9 +30,6 @@ DEVICE, ) from deepmd.pt.utils.learning_rate import LearningRateExp as MyLRExp -from deepmd.pt.utils.stat import ( - make_stat_input, -) from deepmd.tf.common import ( data_requirement, expand_sys_str, @@ -282,9 +279,6 @@ def test_consistency(self): "type_map": self.type_map, }, ) - sampled = make_stat_input( - my_ds.systems, my_ds.dataloaders, self.data_stat_nbatch - ) my_model = get_model( model_params={ "descriptor": { @@ -298,8 +292,7 @@ def test_consistency(self): "fitting_net": {"neuron": self.n_neuron, "distinguish_types": True}, "data_stat_nbatch": self.data_stat_nbatch, "type_map": self.type_map, - }, - sampled=sampled, + } ) my_model.to(DEVICE) my_lr = MyLRExp(self.start_lr, self.stop_lr, self.decay_steps, self.stop_steps) diff --git a/source/tests/pt/test_permutation.py b/source/tests/pt/test_permutation.py index b9724bb2af..d1904f939f 100644 --- a/source/tests/pt/test_permutation.py +++ b/source/tests/pt/test_permutation.py @@ -237,17 +237,15 @@ def test( class TestEnergyModelSeA(unittest.TestCase, PermutationTest): def setUp(self): model_params = copy.deepcopy(model_se_e2_a) - sampled = make_sample(model_params) self.type_split = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA1(unittest.TestCase, PermutationTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA2(unittest.TestCase, PermutationTest): @@ -259,10 +257,9 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestForceModelDPA2(unittest.TestCase, PermutationTest): @@ -274,21 +271,19 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) model_params["fitting_net"]["type"] = "direct_force_ener" self.type_split = True self.test_virial = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) @unittest.skip("hybrid not supported at the moment") class TestEnergyModelHybrid(unittest.TestCase, PermutationTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) @unittest.skip("hybrid not supported at the moment") @@ -296,10 +291,9 @@ class TestForceModelHybrid(unittest.TestCase, PermutationTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) model_params["fitting_net"]["type"] = "direct_force_ener" - sampled = make_sample(model_params) self.type_split = True self.test_virial = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) # class TestEnergyFoo(unittest.TestCase): diff --git a/source/tests/pt/test_permutation_denoise.py b/source/tests/pt/test_permutation_denoise.py index 6dd61ab7e4..eff589e931 100644 --- a/source/tests/pt/test_permutation_denoise.py +++ b/source/tests/pt/test_permutation_denoise.py @@ -15,7 +15,6 @@ ) from .test_permutation import ( # model_dpau, - make_sample, model_dpa1, model_dpa2, model_hybrid, @@ -70,9 +69,8 @@ def test( class TestDenoiseModelDPA1(unittest.TestCase, PermutationDenoiseTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) @unittest.skip("support of the denoise is temporally disabled") @@ -85,19 +83,17 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) # @unittest.skip("hybrid not supported at the moment") # class TestDenoiseModelHybrid(unittest.TestCase, TestPermutationDenoise): # def setUp(self): # model_params = copy.deepcopy(model_hybrid_denoise) -# sampled = make_sample(model_params) # self.type_split = True -# self.model = get_model(model_params, sampled).to(env.DEVICE) +# self.model = get_model(model_params).to(env.DEVICE) if __name__ == "__main__": diff --git a/source/tests/pt/test_rot.py b/source/tests/pt/test_rot.py index 7222fd6f69..437a892cc3 100644 --- a/source/tests/pt/test_rot.py +++ b/source/tests/pt/test_rot.py @@ -15,7 +15,6 @@ ) from .test_permutation import ( # model_dpau, - make_sample, model_dpa1, model_dpa2, model_hybrid, @@ -112,17 +111,15 @@ def test( class TestEnergyModelSeA(unittest.TestCase, RotTest): def setUp(self): model_params = copy.deepcopy(model_se_e2_a) - sampled = make_sample(model_params) self.type_split = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA1(unittest.TestCase, RotTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA2(unittest.TestCase, RotTest): @@ -134,10 +131,9 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestForceModelDPA2(unittest.TestCase, RotTest): @@ -149,21 +145,19 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) model_params["fitting_net"]["type"] = "direct_force_ener" self.type_split = True self.test_virial = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) @unittest.skip("hybrid not supported at the moment") class TestEnergyModelHybrid(unittest.TestCase, RotTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) @unittest.skip("hybrid not supported at the moment") @@ -171,10 +165,9 @@ class TestForceModelHybrid(unittest.TestCase, RotTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) model_params["fitting_net"]["type"] = "direct_force_ener" - sampled = make_sample(model_params) self.type_split = True self.test_virial = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) if __name__ == "__main__": diff --git a/source/tests/pt/test_rot_denoise.py b/source/tests/pt/test_rot_denoise.py index 2cbfd8fd38..35d76437ca 100644 --- a/source/tests/pt/test_rot_denoise.py +++ b/source/tests/pt/test_rot_denoise.py @@ -15,7 +15,6 @@ ) from .test_permutation_denoise import ( - make_sample, model_dpa1, model_dpa2, ) @@ -101,9 +100,8 @@ def test( class TestDenoiseModelDPA1(unittest.TestCase, RotDenoiseTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) @unittest.skip("support of the denoise is temporally disabled") @@ -116,10 +114,9 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) # @unittest.skip("hybrid not supported at the moment") diff --git a/source/tests/pt/test_rotation.py b/source/tests/pt/test_rotation.py index a62e04eb89..5314959673 100644 --- a/source/tests/pt/test_rotation.py +++ b/source/tests/pt/test_rotation.py @@ -21,15 +21,9 @@ from deepmd.pt.utils import ( env, ) -from deepmd.pt.utils.dataloader import ( - DpLoaderSet, -) from deepmd.pt.utils.dataset import ( DeepmdDataSystem, ) -from deepmd.pt.utils.stat import ( - make_stat_input, -) class CheckSymmetry(DeepmdDataSystem): @@ -82,18 +76,7 @@ def setUp(self): self.get_model() def get_model(self): - training_systems = self.config["training"]["training_data"]["systems"] - model_params = self.config["model"] - data_stat_nbatch = model_params.get("data_stat_nbatch", 10) - train_data = DpLoaderSet( - training_systems, - self.config["training"]["training_data"]["batch_size"], - model_params, - ) - sampled = make_stat_input( - train_data.systems, train_data.dataloaders, data_stat_nbatch - ) - self.model = get_model(self.config["model"], sampled).to(env.DEVICE) + self.model = get_model(self.config["model"]).to(env.DEVICE) def get_dataset(self, system_index=0, batch_index=0): systems = self.config["training"]["training_data"]["systems"] diff --git a/source/tests/pt/test_saveload_dpa1.py b/source/tests/pt/test_saveload_dpa1.py index 1b4c41a204..64229b8e9e 100644 --- a/source/tests/pt/test_saveload_dpa1.py +++ b/source/tests/pt/test_saveload_dpa1.py @@ -109,14 +109,13 @@ def get_model_result(self, read=False, model_file="tmp_model.pt"): def create_wrapper(self, read: bool): model_config = copy.deepcopy(self.config["model"]) - sampled = copy.deepcopy(self.sampled) model_config["resuming"] = read model_config["stat_file_dir"] = "stat_files" model_config["stat_file"] = "stat.npz" model_config["stat_file_path"] = os.path.join( model_config["stat_file_dir"], model_config["stat_file"] ) - model = get_model(model_config, sampled).to(env.DEVICE) + model = get_model(model_config).to(env.DEVICE) return ModelWrapper(model, self.loss) def get_data(self): diff --git a/source/tests/pt/test_saveload_se_e2_a.py b/source/tests/pt/test_saveload_se_e2_a.py index 7f8364a16f..0632e30b5b 100644 --- a/source/tests/pt/test_saveload_se_e2_a.py +++ b/source/tests/pt/test_saveload_se_e2_a.py @@ -109,8 +109,7 @@ def get_model_result(self, read=False, model_file="tmp_model.pt"): def create_wrapper(self): model_config = copy.deepcopy(self.config["model"]) - sampled = copy.deepcopy(self.sampled) - model = get_model(model_config, sampled).to(env.DEVICE) + model = get_model(model_config).to(env.DEVICE) return ModelWrapper(model, self.loss) def get_data(self): diff --git a/source/tests/pt/test_smooth.py b/source/tests/pt/test_smooth.py index 2e3bf61d10..2373d486f6 100644 --- a/source/tests/pt/test_smooth.py +++ b/source/tests/pt/test_smooth.py @@ -15,7 +15,6 @@ ) from .test_permutation import ( # model_dpau, - make_sample, model_dpa1, model_dpa2, model_hybrid, @@ -123,9 +122,8 @@ def compare(ret0, ret1): class TestEnergyModelSeA(unittest.TestCase, SmoothTest): def setUp(self): model_params = copy.deepcopy(model_se_e2_a) - sampled = make_sample(model_params) self.type_split = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) self.epsilon, self.aprec = None, None @@ -133,9 +131,8 @@ def setUp(self): class TestEnergyModelDPA1(unittest.TestCase, SmoothTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) # less degree of smoothness, # error can be systematically removed by reducing epsilon self.epsilon = 1e-5 @@ -158,9 +155,8 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) self.epsilon, self.aprec = 1e-5, 1e-4 @@ -175,10 +171,9 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) self.type_split = True self.test_virial = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) self.epsilon, self.aprec = None, None @@ -193,10 +188,9 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) self.type_split = True self.test_virial = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) self.epsilon, self.aprec = None, None @@ -204,17 +198,15 @@ def setUp(self): class TestEnergyModelHybrid(unittest.TestCase, SmoothTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) self.epsilon, self.aprec = None, None # class TestEnergyFoo(unittest.TestCase): # def test(self): # model_params = model_dpau -# sampled = make_sample(model_params) -# self.model = EnergyModelDPAUni(model_params, sampled).to(env.DEVICE) +# self.model = EnergyModelDPAUni(model_params).to(env.DEVICE) # natoms = 5 # cell = torch.rand([3, 3], dtype=dtype) diff --git a/source/tests/pt/test_smooth_denoise.py b/source/tests/pt/test_smooth_denoise.py index de89f8dccc..777d288f3c 100644 --- a/source/tests/pt/test_smooth_denoise.py +++ b/source/tests/pt/test_smooth_denoise.py @@ -15,7 +15,6 @@ ) from .test_permutation_denoise import ( - make_sample, model_dpa2, ) @@ -106,12 +105,11 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) model_params["descriptor"]["sel"] = 8 model_params["descriptor"]["rcut_smth"] = 3.5 self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) self.epsilon, self.aprec = None, None self.epsilon = 1e-7 self.aprec = 1e-5 @@ -127,11 +125,10 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) # model_params["descriptor"]["combine_grrg"] = True self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) self.epsilon, self.aprec = None, None self.epsilon = 1e-7 self.aprec = 1e-5 @@ -141,9 +138,8 @@ def setUp(self): # class TestDenoiseModelHybrid(unittest.TestCase, TestSmoothDenoise): # def setUp(self): # model_params = copy.deepcopy(model_hybrid_denoise) -# sampled = make_sample(model_params) # self.type_split = True -# self.model = get_model(model_params, sampled).to(env.DEVICE) +# self.model = get_model(model_params).to(env.DEVICE) # self.epsilon, self.aprec = None, None # self.epsilon = 1e-7 # self.aprec = 1e-5 diff --git a/source/tests/pt/test_trans.py b/source/tests/pt/test_trans.py index e5d379b9ff..8d8bbe11ff 100644 --- a/source/tests/pt/test_trans.py +++ b/source/tests/pt/test_trans.py @@ -15,7 +15,6 @@ ) from .test_permutation import ( # model_dpau, - make_sample, model_dpa1, model_dpa2, model_hybrid, @@ -68,17 +67,15 @@ def test( class TestEnergyModelSeA(unittest.TestCase, TransTest): def setUp(self): model_params = copy.deepcopy(model_se_e2_a) - sampled = make_sample(model_params) self.type_split = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA1(unittest.TestCase, TransTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA2(unittest.TestCase, TransTest): @@ -90,10 +87,9 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) class TestForceModelDPA2(unittest.TestCase, TransTest): @@ -105,21 +101,19 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) model_params["fitting_net"]["type"] = "direct_force_ener" self.type_split = True self.test_virial = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) @unittest.skip("hybrid not supported at the moment") class TestEnergyModelHybrid(unittest.TestCase, TransTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) @unittest.skip("hybrid not supported at the moment") @@ -127,10 +121,9 @@ class TestForceModelHybrid(unittest.TestCase, TransTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) model_params["fitting_net"]["type"] = "direct_force_ener" - sampled = make_sample(model_params) self.type_split = True self.test_virial = False - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) if __name__ == "__main__": diff --git a/source/tests/pt/test_trans_denoise.py b/source/tests/pt/test_trans_denoise.py index 88b926a3ae..9ba93a244a 100644 --- a/source/tests/pt/test_trans_denoise.py +++ b/source/tests/pt/test_trans_denoise.py @@ -15,7 +15,6 @@ ) from .test_permutation_denoise import ( - make_sample, model_dpa1, model_dpa2, model_hybrid, @@ -60,9 +59,8 @@ def test( class TestDenoiseModelDPA1(unittest.TestCase, TransDenoiseTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) @unittest.skip("support of the denoise is temporally disabled") @@ -75,19 +73,17 @@ def setUp(self): model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ "repinit_nsel" ] - sampled = make_sample(model_params_sample) model_params = copy.deepcopy(model_dpa2) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) @unittest.skip("hybrid not supported at the moment") class TestDenoiseModelHybrid(unittest.TestCase, TransDenoiseTest): def setUp(self): model_params = copy.deepcopy(model_hybrid) - sampled = make_sample(model_params) self.type_split = True - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) if __name__ == "__main__": diff --git a/source/tests/pt/test_unused_params.py b/source/tests/pt/test_unused_params.py index a924979466..f69d8ac835 100644 --- a/source/tests/pt/test_unused_params.py +++ b/source/tests/pt/test_unused_params.py @@ -15,7 +15,6 @@ ) from .test_permutation import ( - make_sample, model_dpa2, ) @@ -57,8 +56,7 @@ def test_unused(self): self._test_unused(model) def _test_unused(self, model_params): - sampled = make_sample(model_params) - self.model = get_model(model_params, sampled).to(env.DEVICE) + self.model = get_model(model_params).to(env.DEVICE) natoms = 5 cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE) cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE)