diff --git a/deepmd/dpmodel/descriptor/hybrid.py b/deepmd/dpmodel/descriptor/hybrid.py index d2620fdcf7..46f2616b84 100644 --- a/deepmd/dpmodel/descriptor/hybrid.py +++ b/deepmd/dpmodel/descriptor/hybrid.py @@ -127,6 +127,14 @@ def mixed_types(self): """ return any(descrpt.mixed_types() for descrpt in self.descrpt_list) + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + raise NotImplementedError + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" for descrpt in self.descrpt_list: diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index 69f0da787f..940bd0cd27 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -4,8 +4,10 @@ abstractmethod, ) from typing import ( + Callable, List, Optional, + Union, ) from deepmd.common import ( @@ -84,8 +86,19 @@ def mixed_types(self) -> bool: """ pass + @abstractmethod + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + pass + def compute_input_stats( - self, merged: List[dict], path: Optional[DPPath] = None + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, ): """Update mean and stddev for descriptor elements.""" raise NotImplementedError diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 5e72653f1d..f6b1c5677e 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -243,6 +243,14 @@ def mixed_types(self): """ return False + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + raise NotImplementedError + def get_ntypes(self) -> int: """Returns the number of element types.""" return self.ntypes diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index a5dcfb16dd..2dbf495d14 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -203,6 +203,14 @@ def mixed_types(self): """ return False + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + raise NotImplementedError + def get_ntypes(self) -> int: """Returns the number of element types.""" return self.ntypes diff --git a/deepmd/dpmodel/model/dp_model.py b/deepmd/dpmodel/model/dp_model.py index ef7866a6dd..15f9027d4c 100644 --- a/deepmd/dpmodel/model/dp_model.py +++ b/deepmd/dpmodel/model/dp_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later + from deepmd.dpmodel.atomic_model import ( DPAtomicModel, ) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index c4b5a4cf44..023bc5305e 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -53,9 +53,6 @@ from deepmd.pt.utils.multi_task import ( preprocess_shared_params, ) -from deepmd.pt.utils.stat import ( - make_stat_input, -) from deepmd.utils.argcheck import ( normalize, ) @@ -104,36 +101,23 @@ def get_trainer( config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None) def prepare_trainer_input_single( - model_params_single, data_dict_single, loss_dict_single, suffix="" + model_params_single, data_dict_single, loss_dict_single, suffix="", rank=0 ): training_dataset_params = data_dict_single["training_data"] type_split = False if model_params_single["descriptor"]["type"] in ["se_e2_a"]: type_split = True - validation_dataset_params = data_dict_single["validation_data"] + validation_dataset_params = data_dict_single.get("validation_data", None) + validation_systems = ( + validation_dataset_params["systems"] if validation_dataset_params else None + ) training_systems = training_dataset_params["systems"] - validation_systems = validation_dataset_params["systems"] - - # noise params - noise_settings = None - if loss_dict_single.get("type", "ener") == "denoise": - noise_settings = { - "noise_type": loss_dict_single.pop("noise_type", "uniform"), - "noise": loss_dict_single.pop("noise", 1.0), - "noise_mode": loss_dict_single.pop("noise_mode", "fix_num"), - "mask_num": loss_dict_single.pop("mask_num", 8), - "mask_prob": loss_dict_single.pop("mask_prob", 0.15), - "same_mask": loss_dict_single.pop("same_mask", False), - "mask_coord": loss_dict_single.pop("mask_coord", False), - "mask_type": loss_dict_single.pop("mask_type", False), - "max_fail_num": loss_dict_single.pop("max_fail_num", 10), - "mask_type_idx": len(model_params_single["type_map"]) - 1, - } - # noise_settings = None # stat files stat_file_path_single = data_dict_single.get("stat_file", None) - if stat_file_path_single is not None: + if rank != 0: + stat_file_path_single = None + elif stat_file_path_single is not None: if Path(stat_file_path_single).is_dir(): raise ValueError( f"stat_file should be a file, not a directory: {stat_file_path_single}" @@ -144,10 +128,14 @@ def prepare_trainer_input_single( stat_file_path_single = DPPath(stat_file_path_single, "a") # validation and training data - validation_data_single = DpLoaderSet( - validation_systems, - validation_dataset_params["batch_size"], - model_params_single, + validation_data_single = ( + DpLoaderSet( + validation_systems, + validation_dataset_params["batch_size"], + model_params_single, + ) + if validation_systems + else None ) if ckpt or finetune_model: train_data_single = DpLoaderSet( @@ -155,60 +143,48 @@ def prepare_trainer_input_single( training_dataset_params["batch_size"], model_params_single, ) - sampled_single = None else: train_data_single = DpLoaderSet( training_systems, training_dataset_params["batch_size"], model_params_single, ) - data_stat_nbatch = model_params_single.get("data_stat_nbatch", 10) - sampled_single = make_stat_input( - train_data_single.systems, - train_data_single.dataloaders, - data_stat_nbatch, - ) - if noise_settings is not None: - train_data_single = DpLoaderSet( - training_systems, - training_dataset_params["batch_size"], - model_params_single, - ) return ( train_data_single, validation_data_single, - sampled_single, stat_file_path_single, ) + rank = dist.get_rank() if dist.is_initialized() else 0 if not multi_task: ( train_data, validation_data, - sampled, stat_file_path, ) = prepare_trainer_input_single( - config["model"], config["training"], config["loss"] + config["model"], + config["training"], + config["loss"], + rank=rank, ) else: - train_data, validation_data, sampled, stat_file_path = {}, {}, {}, {} + train_data, validation_data, stat_file_path = {}, {}, {} for model_key in config["model"]["model_dict"]: ( train_data[model_key], validation_data[model_key], - sampled[model_key], stat_file_path[model_key], ) = prepare_trainer_input_single( config["model"]["model_dict"][model_key], config["training"]["data_dict"][model_key], config["loss_dict"][model_key], suffix=f"_{model_key}", + rank=rank, ) trainer = training.Trainer( config, train_data, - sampled=sampled, stat_file_path=stat_file_path, validation_data=validation_data, init_model=init_model, diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py index 4ed765cf69..2834733112 100644 --- a/deepmd/pt/loss/ener.py +++ b/deepmd/pt/loss/ener.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, +) + import torch import torch.nn.functional as F @@ -11,6 +15,9 @@ from deepmd.pt.utils.env import ( GLOBAL_PT_FLOAT_PRECISION, ) +from deepmd.utils.data import ( + DataRequirementItem, +) class EnergyStdLoss(TaskLoss): @@ -23,16 +30,57 @@ def __init__( limit_pref_f=0.0, start_pref_v=0.0, limit_pref_v=0.0, + start_pref_ae: float = 0.0, + limit_pref_ae: float = 0.0, + start_pref_pf: float = 0.0, + limit_pref_pf: float = 0.0, use_l1_all: bool = False, inference=False, **kwargs, ): - """Construct a layer to compute loss on energy, force and virial.""" + r"""Construct a layer to compute loss on energy, force and virial. + + Parameters + ---------- + starter_learning_rate : float + The learning rate at the start of the training. + start_pref_e : float + The prefactor of energy loss at the start of the training. + limit_pref_e : float + The prefactor of energy loss at the end of the training. + start_pref_f : float + The prefactor of force loss at the start of the training. + limit_pref_f : float + The prefactor of force loss at the end of the training. + start_pref_v : float + The prefactor of virial loss at the start of the training. + limit_pref_v : float + The prefactor of virial loss at the end of the training. + start_pref_ae : float + The prefactor of atomic energy loss at the start of the training. + limit_pref_ae : float + The prefactor of atomic energy loss at the end of the training. + start_pref_pf : float + The prefactor of atomic prefactor force loss at the start of the training. + limit_pref_pf : float + The prefactor of atomic prefactor force loss at the end of the training. + use_l1_all : bool + Whether to use L1 loss, if False (default), it will use L2 loss. + inference : bool + If true, it will output all losses found in output, ignoring the pre-factors. + **kwargs + Other keyword arguments. + """ super().__init__() self.starter_learning_rate = starter_learning_rate self.has_e = (start_pref_e != 0.0 and limit_pref_e != 0.0) or inference self.has_f = (start_pref_f != 0.0 and limit_pref_f != 0.0) or inference self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference + + # TODO need support for atomic energy and atomic pref + self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference + self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference + self.start_pref_e = start_pref_e self.limit_pref_e = limit_pref_e self.start_pref_f = start_pref_f @@ -153,3 +201,60 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False): if not self.inference: more_loss["rmse"] = torch.sqrt(loss.detach()) return loss, more_loss + + @property + def label_requirement(self) -> List[DataRequirementItem]: + """Return data label requirements needed for this loss calculation.""" + label_requirement = [] + if self.has_e: + label_requirement.append( + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + must=False, + high_prec=True, + ) + ) + if self.has_f: + label_requirement.append( + DataRequirementItem( + "force", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ) + ) + if self.has_v: + label_requirement.append( + DataRequirementItem( + "virial", + ndof=9, + atomic=False, + must=False, + high_prec=False, + ) + ) + if self.has_ae: + label_requirement.append( + DataRequirementItem( + "atom_ener", + ndof=1, + atomic=True, + must=False, + high_prec=False, + ) + ) + if self.has_pf: + label_requirement.append( + DataRequirementItem( + "atom_pref", + ndof=1, + atomic=True, + must=False, + high_prec=False, + repeat=3, + ) + ) + return label_requirement diff --git a/deepmd/pt/loss/loss.py b/deepmd/pt/loss/loss.py index 9f2c3a7ed7..925ff8f4ef 100644 --- a/deepmd/pt/loss/loss.py +++ b/deepmd/pt/loss/loss.py @@ -1,8 +1,20 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + ABC, + abstractmethod, +) +from typing import ( + List, +) + import torch +from deepmd.utils.data import ( + DataRequirementItem, +) + -class TaskLoss(torch.nn.Module): +class TaskLoss(torch.nn.Module, ABC): def __init__(self, **kwargs): """Construct loss.""" super().__init__() @@ -10,3 +22,9 @@ def __init__(self, **kwargs): def forward(self, model_pred, label, natoms, learning_rate): """Return loss .""" raise NotImplementedError + + @property + @abstractmethod + def label_requirement(self) -> List[DataRequirementItem]: + """Return data label requirements needed for this loss calculation.""" + pass diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 63e91ff428..7f6c3076d8 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -18,9 +18,6 @@ from deepmd.pt.model.task.base_fitting import ( BaseFitting, ) -from deepmd.pt.utils.utils import ( - dict_to_device, -) from deepmd.utils.path import ( DPPath, ) @@ -185,7 +182,7 @@ def forward_atomic( def compute_or_load_stat( self, - sampled, + sampled_func, stat_file_path: Optional[DPPath] = None, ): """ @@ -198,8 +195,8 @@ def compute_or_load_stat( Parameters ---------- - sampled - The sampled data frames from different data systems. + sampled_func + The lazy sampled function to get data frames from different data systems. stat_file_path The dictionary of paths to the statistics files. """ @@ -207,13 +204,9 @@ def compute_or_load_stat( # descriptors and fitting net with different type_map # should not share the same parameters stat_file_path /= " ".join(self.type_map) - for data_sys in sampled: - dict_to_device(data_sys) - if sampled is None: - sampled = [] - self.descriptor.compute_input_stats(sampled, stat_file_path) + self.descriptor.compute_input_stats(sampled_func, stat_file_path) if self.fitting_net is not None: - self.fitting_net.compute_output_stats(sampled, stat_file_path) + self.fitting_net.compute_output_stats(sampled_func, stat_file_path) @torch.jit.export def get_dim_fparam(self) -> int: diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py index 72f734de04..325cf29e42 100644 --- a/deepmd/pt/model/descriptor/__init__.py +++ b/deepmd/pt/model/descriptor/__init__.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .base_descriptor import ( + BaseDescriptor, +) from .descriptor import ( DescriptorBlock, make_default_type_embedding, @@ -32,6 +35,7 @@ ) __all__ = [ + "BaseDescriptor", "DescriptorBlock", "make_default_type_embedding", "DescrptBlockSeA", diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 964cdb01eb..24c1ef4dab 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -5,9 +5,11 @@ abstractmethod, ) from typing import ( + Callable, Dict, List, Optional, + Union, ) import torch @@ -86,8 +88,27 @@ def get_dim_emb(self) -> int: """Returns the embedding dimension.""" pass - def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): - """Update mean and stddev for DescriptorBlock elements.""" + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ raise NotImplementedError def get_stats(self) -> Dict[str, StatItem]: @@ -95,6 +116,11 @@ def get_stats(self) -> Dict[str, StatItem]: raise NotImplementedError def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ assert ( self.__class__ == base_class.__class__ ), "Only descriptors of the same type can share params!" diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 0245179d8b..224a24d60e 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Callable, List, Optional, + Union, ) import torch @@ -145,6 +147,29 @@ def mixed_types(self) -> bool: """ return self.se_atten.mixed_types() + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + assert ( + self.__class__ == base_class.__class__ + ), "Only descriptors of the same type can share params!" + # For DPA1 descriptors, the user-defined share-level + # shared_level: 0 + # share all parameters in both type_embedding and se_atten + if shared_level == 0: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + self.se_atten.share_params(base_class.se_atten, 0, resume=resume) + # shared_level: 1 + # share all parameters in type_embedding + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + # Other shared levels + else: + raise NotImplementedError + @property def dim_out(self): return self.get_dim_out() @@ -153,7 +178,27 @@ def dim_out(self): def dim_emb(self): return self.get_dim_emb() - def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ return self.se_atten.compute_input_stats(merged, path) def serialize(self) -> dict: diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 20a7c74cda..dcb381d53a 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Callable, List, Optional, + Union, ) import torch @@ -289,6 +291,46 @@ def mixed_types(self) -> bool: """ return True + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + assert ( + self.__class__ == base_class.__class__ + ), "Only descriptors of the same type can share params!" + # For DPA2 descriptors, the user-defined share-level + # shared_level: 0 + # share all parameters in type_embedding, repinit and repformers + if shared_level == 0: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + self.repinit.share_params(base_class.repinit, 0, resume=resume) + self._modules["g1_shape_tranform"] = base_class._modules[ + "g1_shape_tranform" + ] + self.repformers.share_params(base_class.repformers, 0, resume=resume) + # shared_level: 1 + # share all parameters in type_embedding and repinit + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + self.repinit.share_params(base_class.repinit, 0, resume=resume) + # shared_level: 2 + # share all parameters in type_embedding and repformers + elif shared_level == 2: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + self._modules["g1_shape_tranform"] = base_class._modules[ + "g1_shape_tranform" + ] + self.repformers.share_params(base_class.repformers, 0, resume=resume) + # shared_level: 3 + # share all parameters in type_embedding + elif shared_level == 3: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + # Other shared levels + else: + raise NotImplementedError + @property def dim_out(self): return self.get_dim_out() @@ -298,16 +340,29 @@ def dim_emb(self): """Returns the embedding dimension g2.""" return self.get_dim_emb() - def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ for ii, descrpt in enumerate([self.repinit, self.repformers]): - merged_tmp = [ - { - key: item[key] if not isinstance(item[key], list) else item[key][ii] - for key in item - } - for item in merged - ] - descrpt.compute_input_stats(merged_tmp, path) + descrpt.compute_input_stats(merged, path) def serialize(self) -> dict: """Serialize the obj to dict.""" diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index 5aa83ef534..b53adca462 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, + Callable, Dict, List, Optional, @@ -139,6 +140,23 @@ def mixed_types(self): """ return any(descrpt.mixed_types() for descrpt in self.descrpt_list) + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + assert ( + self.__class__ == base_class.__class__ + ), "Only descriptors of the same type can share params!" + if shared_level == 0: + for ii, des in enumerate(self.descrpt_list): + self.descrpt_list[ii].share_params( + base_class.descrpt_list[ii], shared_level, resume=resume + ) + else: + raise NotImplementedError + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" for descrpt in self.descrpt_list: @@ -383,6 +401,11 @@ def dim_emb(self): raise RuntimeError def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ assert ( self.__class__ == base_class.__class__ ), "Only descriptors of the same type can share params!" @@ -391,22 +414,33 @@ def share_params(self, base_class, shared_level, resume=False): self.descriptor_list[ii].share_params( base_class.descriptor_list[ii], shared_level, resume=resume ) - if self.hybrid_mode == "sequential": - self.sequential_transform = base_class.sequential_transform else: raise NotImplementedError - def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): - """Update mean and stddev for descriptor elements.""" + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ for ii, descrpt in enumerate(self.descriptor_list): - merged_tmp = [ - { - key: item[key] if not isinstance(item[key], list) else item[key][ii] - for key in item - } - for item in merged - ] - descrpt.compute_input_stats(merged_tmp, path) + # need support for hybrid descriptors + descrpt.compute_input_stats(merged, path) def forward( self, diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 2425139e16..3e8bf72f77 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Callable, Dict, List, Optional, + Union, ) import torch @@ -278,12 +280,39 @@ def forward( return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw - def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): - """Update mean and stddev for descriptor elements.""" + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ env_mat_stat = EnvMatStatSe(self) if path is not None: path = path / env_mat_stat.get_hash() - env_mat_stat.load_or_compute_stats(merged, path) + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() if not self.set_davg_zero: diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index fc2cf60531..d836b48992 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import itertools from typing import ( + Callable, ClassVar, Dict, List, Optional, Tuple, + Union, ) import numpy as np @@ -127,13 +129,50 @@ def mixed_types(self): """ return self.sea.mixed_types() + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + assert ( + self.__class__ == base_class.__class__ + ), "Only descriptors of the same type can share params!" + # For SeA descriptors, the user-defined share-level + # shared_level: 0 + # share all parameters in sea + if shared_level == 0: + self.sea.share_params(base_class.sea, 0, resume=resume) + # Other shared levels + else: + raise NotImplementedError + @property def dim_out(self): """Returns the output dimension of this descriptor.""" return self.sea.dim_out - def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): - """Update mean and stddev for descriptor elements.""" + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ return self.sea.compute_input_stats(merged, path) def reinit_exclude( @@ -411,12 +450,39 @@ def __getitem__(self, key): else: raise KeyError(key) - def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): - """Update mean and stddev for descriptor elements.""" + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ env_mat_stat = EnvMatStatSe(self) if path is not None: path = path / env_mat_stat.get_hash() - env_mat_stat.load_or_compute_stats(merged, path) + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() if not self.set_davg_zero: diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index c815cda013..db9202c7fc 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Callable, Dict, List, Optional, + Union, ) import numpy as np @@ -200,12 +202,39 @@ def dim_emb(self): """Returns the output dimension of embedding.""" return self.get_dim_emb() - def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): - """Update mean and stddev for descriptor elements.""" + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ env_mat_stat = EnvMatStatSe(self) if path is not None: path = path / env_mat_stat.get_hash() - env_mat_stat.load_or_compute_stats(merged, path) + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() if not self.set_davg_zero: diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index 16721fbe5e..27e459d861 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Callable, Dict, List, Optional, Tuple, + Union, ) import numpy as np @@ -151,12 +153,72 @@ def mixed_types(self) -> bool: """ return False - def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): - """Update mean and stddev for descriptor elements.""" + def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + assert ( + self.__class__ == base_class.__class__ + ), "Only descriptors of the same type can share params!" + # For SeR descriptors, the user-defined share-level + # shared_level: 0 + if shared_level == 0: + # link buffers + if hasattr(self, "mean") and not resume: + # in case of change params during resume + base_env = EnvMatStatSe(base_class) + base_env.stats = base_class.stats + for kk in base_class.get_stats(): + base_env.stats[kk] += self.get_stats()[kk] + mean, stddev = base_env() + if not base_class.set_davg_zero: + base_class.mean.copy_(torch.tensor(mean, device=env.DEVICE)) + base_class.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) + self.mean = base_class.mean + self.stddev = base_class.stddev + # self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model + # the following will successfully link all the params except buffers + for item in self._modules: + self._modules[item] = base_class._modules[item] + # Other shared levels + else: + raise NotImplementedError + + def compute_input_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + path: Optional[DPPath] = None, + ): + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ env_mat_stat = EnvMatStatSe(self) if path is not None: path = path / env_mat_stat.get_hash() - env_mat_stat.load_or_compute_stats(merged, path) + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() if not self.set_davg_zero: diff --git a/deepmd/pt/model/network/network.py b/deepmd/pt/model/network/network.py index 9ef7b3366a..10d0364c9b 100644 --- a/deepmd/pt/model/network/network.py +++ b/deepmd/pt/model/network/network.py @@ -575,6 +575,11 @@ def forward(self, atype): return self.embedding(atype) def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ assert ( self.__class__ == base_class.__class__ ), "Only TypeEmbedNet of the same type can share params!" diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 9df3a5fb32..7d2dd221db 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging from typing import ( + Callable, List, Optional, + Union, ) import torch @@ -20,6 +22,9 @@ from deepmd.pt.utils.env import ( DEFAULT_PRECISION, ) +from deepmd.utils.path import ( + DPPath, +) log = logging.getLogger(__name__) @@ -67,7 +72,6 @@ class DipoleFittingNet(GeneralFitting): def __init__( self, - var_name: str, ntypes: int, dim_descrpt: int, embedding_width: int, @@ -89,7 +93,7 @@ def __init__( self.r_differentiable = r_differentiable self.c_differentiable = c_differentiable super().__init__( - var_name=var_name, + var_name=kwargs.pop("var_name", "dipole"), ntypes=ntypes, dim_descrpt=dim_descrpt, neuron=neuron, @@ -132,6 +136,29 @@ def output_def(self) -> FittingOutputDef: ] ) + def compute_output_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + stat_file_path: Optional[DPPath] = None, + ): + """ + Compute the output statistics (e.g. energy bias) for the fitting net from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + stat_file_path : Optional[DPPath] + The path to the stat file. + + """ + raise NotImplementedError + def forward( self, descriptor: torch.Tensor, diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index ff7ae6f8ec..29ed5acaad 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -2,9 +2,11 @@ import copy import logging from typing import ( + Callable, List, Optional, Tuple, + Union, ) import numpy as np @@ -138,18 +140,43 @@ def serialize(self) -> dict: data["atom_ener"] = self.atom_ener return data - def compute_output_stats(self, merged, stat_file_path: Optional[DPPath] = None): - energy = [item[self.var_name] for item in merged] - data_mixed_type = "real_natoms_vec" in merged[0] - if data_mixed_type: - input_natoms = [item["real_natoms_vec"] for item in merged] - else: - input_natoms = [item["natoms"] for item in merged] + def compute_output_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + stat_file_path: Optional[DPPath] = None, + ): + """ + Compute the output statistics (e.g. energy bias) for the fitting net from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + stat_file_path : Optional[DPPath] + The path to the stat file. + + """ if stat_file_path is not None: stat_file_path = stat_file_path / "bias_atom_e" if stat_file_path is not None and stat_file_path.is_file(): bias_atom_e = stat_file_path.load_numpy() else: + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + energy = [item["energy"] for item in sampled] + data_mixed_type = "real_natoms_vec" in sampled[0] + if data_mixed_type: + input_natoms = [item["real_natoms_vec"] for item in sampled] + else: + input_natoms = [item["natoms"] for item in sampled] # shape: (nframes, ndim) merged_energy = to_numpy_array(torch.cat(energy)) # shape: (nframes, ntypes) @@ -320,7 +347,6 @@ def __init__( self.filter_layers = torch.nn.ModuleList(filter_layers) if "seed" in kwargs: - log.info("Set seed to %d in fitting net.", kwargs["seed"]) torch.manual_seed(kwargs["seed"]) def output_def(self): diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 8e8338210f..47535580db 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -62,6 +62,11 @@ def __new__(cls, *args, **kwargs): return super().__new__(cls) def share_params(self, base_class, shared_level, resume=False): + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ assert ( self.__class__ == base_class.__class__ ), "Only fitting nets of the same type can share params!" @@ -77,18 +82,6 @@ def share_params(self, base_class, shared_level, resume=False): # the following will successfully link all the params except buffers, which need manually link. for item in self._modules: self._modules[item] = base_class._modules[item] - elif shared_level == 2: - # share all the layers before final layer - # the following will successfully link all the params except buffers, which need manually link. - self._modules["filter_layers"][0].deep_layers = base_class._modules[ - "filter_layers" - ][0].deep_layers - elif shared_level == 3: - # share the first layers - # the following will successfully link all the params except buffers, which need manually link. - self._modules["filter_layers"][0].deep_layers[0] = base_class._modules[ - "filter_layers" - ][0].deep_layers[0] else: raise NotImplementedError @@ -354,7 +347,6 @@ def __init__( self.filter_layers_old = None if seed is not None: - log.info("Set seed to %d in fitting net.", seed) torch.manual_seed(seed) def reinit_exclude( diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 1bc4798c48..9483d1eb4a 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging from typing import ( + Callable, List, Optional, Union, @@ -24,6 +25,9 @@ from deepmd.pt.utils.utils import ( to_numpy_array, ) +from deepmd.utils.path import ( + DPPath, +) log = logging.getLogger(__name__) @@ -72,7 +76,6 @@ class PolarFittingNet(GeneralFitting): def __init__( self, - var_name: str, ntypes: int, dim_descrpt: int, embedding_width: int, @@ -112,7 +115,7 @@ def __init__( ).view(ntypes, 1) self.shift_diag = shift_diag super().__init__( - var_name=var_name, + var_name=kwargs.pop("var_name", "polar"), ntypes=ntypes, dim_descrpt=dim_descrpt, neuron=neuron, @@ -160,6 +163,29 @@ def output_def(self) -> FittingOutputDef: ] ) + def compute_output_stats( + self, + merged: Union[Callable[[], List[dict]], List[dict]], + stat_file_path: Optional[DPPath] = None, + ): + """ + Compute the output statistics (e.g. energy bias) for the fitting net from packed data. + + Parameters + ---------- + merged : Union[Callable[[], List[dict]], List[dict]] + - List[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], List[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + stat_file_path : Optional[DPPath] + The path to the stat file. + + """ + raise NotImplementedError + def forward( self, descriptor: torch.Tensor, diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 152c69a444..ef8a53e656 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import functools import logging import time from copy import ( @@ -49,6 +50,9 @@ from deepmd.pt.utils.learning_rate import ( LearningRateExp, ) +from deepmd.pt.utils.stat import ( + make_stat_input, +) if torch.__version__.startswith("2"): import torch._dynamo @@ -59,6 +63,10 @@ DataLoader, ) +from deepmd.utils.path import ( + DPH5Path, +) + log = logging.getLogger(__name__) @@ -67,7 +75,6 @@ def __init__( self, config: Dict[str, Any], training_data, - sampled=None, stat_file_path=None, validation_data=None, init_model=None, @@ -82,7 +89,15 @@ def __init__( Args: - config: The Dict-like configuration with training options. """ - resume_model = init_model if init_model is not None else restart_model + if init_model is not None: + resume_model = init_model + elif restart_model is not None: + resume_model = restart_model + elif finetune_model is not None: + resume_model = finetune_model + else: + resume_model = None + resuming = resume_model is not None self.restart_training = restart_model is not None model_params = config["model"] training_params = config["training"] @@ -93,8 +108,6 @@ def __init__( self.model_keys = ( list(model_params["model_dict"]) if self.multi_task else ["Default"] ) - if self.multi_task and sampled is None: - sampled = {key: None for key in self.model_keys} self.rank = dist.get_rank() if dist.is_initialized() else 0 self.world_size = dist.get_world_size() if dist.is_initialized() else 1 self.num_model = len(self.model_keys) @@ -119,62 +132,51 @@ def get_opt_param(params): return opt_type, opt_param def get_data_loader(_training_data, _validation_data, _training_params): - if "auto_prob" in _training_params["training_data"]: - train_sampler = get_weighted_sampler( - _training_data, _training_params["training_data"]["auto_prob"] - ) - elif "sys_probs" in _training_params["training_data"]: - train_sampler = get_weighted_sampler( - _training_data, - _training_params["training_data"]["sys_probs"], - sys_prob=True, + def get_dataloader_and_buffer(_data, _params): + if "auto_prob" in _training_params["training_data"]: + _sampler = get_weighted_sampler( + _data, _params["training_data"]["auto_prob"] + ) + elif "sys_probs" in _training_params["training_data"]: + _sampler = get_weighted_sampler( + _data, + _params["training_data"]["sys_probs"], + sys_prob=True, + ) + else: + _sampler = get_weighted_sampler(_data, "prob_sys_size") + + if _sampler is None: + log.warning( + "Sampler not specified!" + ) # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration. + _dataloader = DataLoader( + _data, + sampler=_sampler, + batch_size=None, + num_workers=NUM_WORKERS, # setting to 0 diverges the behavior of its iterator; should be >=1 + drop_last=False, + pin_memory=True, ) - else: - train_sampler = get_weighted_sampler(_training_data, "prob_sys_size") + with torch.device("cpu"): + _data_buffered = BufferedIterator(iter(_dataloader)) + return _dataloader, _data_buffered - if "auto_prob" in _training_params["validation_data"]: - valid_sampler = get_weighted_sampler( - _validation_data, _training_params["validation_data"]["auto_prob"] - ) - elif "sys_probs" in _training_params["validation_data"]: - valid_sampler = get_weighted_sampler( - _validation_data, - _training_params["validation_data"]["sys_probs"], - sys_prob=True, - ) - else: - valid_sampler = get_weighted_sampler(_validation_data, "prob_sys_size") - - if train_sampler is None or valid_sampler is None: - log.warning( - "Sampler not specified!" - ) # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration. - training_dataloader = DataLoader( - _training_data, - sampler=train_sampler, - batch_size=None, - num_workers=NUM_WORKERS, # setting to 0 diverges the behavior of its iterator; should be >=1 - drop_last=False, - pin_memory=True, - ) - with torch.device("cpu"): - training_data_buffered = BufferedIterator(iter(training_dataloader)) - validation_dataloader = DataLoader( - _validation_data, - sampler=valid_sampler, - batch_size=None, - num_workers=min(NUM_WORKERS, 1), - drop_last=False, - pin_memory=True, + training_dataloader, training_data_buffered = get_dataloader_and_buffer( + _training_data, _training_params ) - with torch.device("cpu"): - validation_data_buffered = BufferedIterator(iter(validation_dataloader)) - if _training_params.get("validation_data", None) is not None: + if _validation_data is not None: + ( + validation_dataloader, + validation_data_buffered, + ) = get_dataloader_and_buffer(_validation_data, _training_params) valid_numb_batch = _training_params["validation_data"].get( "numb_btch", 1 ) else: + validation_dataloader = None + validation_data_buffered = None valid_numb_batch = 1 return ( training_dataloader, @@ -184,13 +186,34 @@ def get_data_loader(_training_data, _validation_data, _training_params): valid_numb_batch, ) - def get_single_model(_model_params, _sampled, _stat_file_path): + def get_single_model( + _model_params, + _training_data, + _validation_data, + _stat_file_path, + _data_requirement, + ): model = get_model(deepcopy(_model_params)).to(DEVICE) - if not model_params.get("resuming", False): + _training_data.add_data_requirement(_data_requirement) + if _validation_data is not None: + _validation_data.add_data_requirement(_data_requirement) + if not resuming and self.rank == 0: + + @functools.lru_cache + def get_sample(): + sampled = make_stat_input( + _training_data.systems, + _training_data.dataloaders, + _model_params.get("data_stat_nbatch", 10), + ) + return sampled + model.compute_or_load_stat( - sampled=_sampled, + sampled_func=get_sample, stat_file_path=_stat_file_path, ) + if isinstance(_stat_file_path, DPH5Path): + _stat_file_path.root.close() return model def get_lr(lr_params): @@ -230,9 +253,34 @@ def get_loss(loss_params, start_lr, _ntypes): else: self.opt_type, self.opt_param = get_opt_param(training_params) + # Loss + if not self.multi_task: + self.loss = get_loss( + config["loss"], + config["learning_rate"]["start_lr"], + len(model_params["type_map"]), + ) + else: + self.loss = {} + for model_key in self.model_keys: + loss_param = config["loss_dict"][model_key] + if config.get("learning_rate_dict", None) is not None: + lr_param = config["learning_rate_dict"][model_key]["start_lr"] + else: + lr_param = config["learning_rate"]["start_lr"] + ntypes = len(model_params["model_dict"][model_key]["type_map"]) + self.loss[model_key] = get_loss(loss_param, lr_param, ntypes) + # Data + Model dp_random.seed(training_params["seed"]) if not self.multi_task: + self.model = get_single_model( + model_params, + training_data, + validation_data, + stat_file_path, + self.loss.label_requirement, + ) ( self.training_dataloader, self.training_data, @@ -240,7 +288,6 @@ def get_loss(loss_params, start_lr, _ntypes): self.validation_data, self.valid_numb_batch, ) = get_data_loader(training_data, validation_data, training_params) - self.model = get_single_model(model_params, sampled, stat_file_path) else: ( self.training_dataloader, @@ -251,6 +298,13 @@ def get_loss(loss_params, start_lr, _ntypes): self.model, ) = {}, {}, {}, {}, {}, {} for model_key in self.model_keys: + self.model[model_key] = get_single_model( + model_params["model_dict"][model_key], + training_data[model_key], + validation_data[model_key], + stat_file_path[model_key], + self.loss[model_key].label_requirement, + ) ( self.training_dataloader[model_key], self.training_data[model_key], @@ -262,11 +316,6 @@ def get_loss(loss_params, start_lr, _ntypes): validation_data[model_key], training_params["data_dict"][model_key], ) - self.model[model_key] = get_single_model( - model_params["model_dict"][model_key], - sampled[model_key], - stat_file_path[model_key], - ) # Learning rate self.warmup_steps = training_params.get("warmup_steps", 0) @@ -281,24 +330,6 @@ def get_loss(loss_params, start_lr, _ntypes): else: self.lr_exp = get_lr(config["learning_rate"]) - # Loss - if not self.multi_task: - self.loss = get_loss( - config["loss"], - config["learning_rate"]["start_lr"], - len(model_params["type_map"]), - ) - else: - self.loss = {} - for model_key in self.model_keys: - loss_param = config["loss_dict"][model_key] - if config.get("learning_rate_dict", None) is not None: - lr_param = config["learning_rate_dict"][model_key]["start_lr"] - else: - lr_param = config["learning_rate"]["start_lr"] - ntypes = len(model_params["model_dict"][model_key]["type_map"]) - self.loss[model_key] = get_loss(loss_param, lr_param, ntypes) - # JIT if JIT: self.model = torch.jit.script(self.model) @@ -309,7 +340,7 @@ def get_loss(loss_params, start_lr, _ntypes): # resuming and finetune optimizer_state_dict = None - if model_params["resuming"]: + if resuming: ntest = model_params.get("data_bias_nsample", 1) origin_model = ( finetune_model if finetune_model is not None else resume_model @@ -404,7 +435,7 @@ def get_loss(loss_params, start_lr, _ntypes): # Multi-task share params if shared_links is not None: - self.wrapper.share_params(shared_links, resume=model_params["resuming"]) + self.wrapper.share_params(shared_links, resume=resuming or self.rank != 0) if dist.is_initialized(): torch.cuda.set_device(LOCAL_RANK) @@ -617,6 +648,9 @@ def log_loss_valid(_task_key="Default"): input_dict, label_dict, _ = self.get_data( is_train=False, task_key=_task_key ) + if input_dict == {}: + # no validation data + return "", None _, loss, more_loss = self.wrapper( **input_dict, cur_lr=pref_lr, @@ -778,6 +812,8 @@ def get_data(self, is_train=True, task_key="Default"): ) batch_data = next(iter(self.training_data)) else: + if self.validation_data is None: + return {}, {}, {} try: batch_data = next(iter(self.validation_data)) except StopIteration: @@ -796,6 +832,8 @@ def get_data(self, is_train=True, task_key="Default"): ) batch_data = next(iter(self.training_data[task_key])) else: + if self.validation_data[task_key] is None: + return {}, {}, {} try: batch_data = next(iter(self.validation_data[task_key])) except StopIteration: @@ -812,28 +850,24 @@ def get_data(self, is_train=True, task_key="Default"): batch_data[key] = batch_data[key].to(DEVICE) else: batch_data[key] = [item.to(DEVICE) for item in batch_data[key]] - input_dict = {} - for item in [ + # we may need a better way to classify which are inputs and which are labels + # now wrapper only supports the following inputs: + input_keys = [ "coord", "atype", "box", - ]: - if item in batch_data: - input_dict[item] = batch_data[item] - else: - input_dict[item] = None + "spin", + "fparam", + "aparam", + ] + input_dict = {item_key: None for item_key in input_keys} label_dict = {} - for item in [ - "energy", - "force", - "virial", - "clean_coord", - "clean_type", - "coord_mask", - "type_mask", - ]: - if item in batch_data: - label_dict[item] = batch_data[item] + for item_key in batch_data: + if item_key in input_keys: + input_dict[item_key] = batch_data[item_key] + else: + if item_key not in ["sid", "fid"] and "find_" not in item_key: + label_dict[item_key] = batch_data[item_key] log_dict = {} if "fid" in batch_data: log_dict["fid"] = batch_data["fid"] diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index 74b4a83ce7..67f8043653 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -61,7 +61,7 @@ def __init__( self.inference_only = self.loss is None def set_trainable_params(self): - supported_types = ["type_embedding", "descriptor", "fitting_net"] + supported_types = ["descriptor", "fitting_net"] for model_item in self.model: for net_type in supported_types: trainable = True @@ -83,7 +83,12 @@ def set_trainable_params(self): param.requires_grad = trainable def share_params(self, shared_links, resume=False): - supported_types = ["type_embedding", "descriptor", "fitting_net"] + """ + Share the parameters of classes following rules defined in shared_links during multitask training. + If not start from checkpoint (resume is False), + some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + supported_types = ["descriptor", "fitting_net"] for shared_item in shared_links: class_name = shared_links[shared_item]["type"] shared_base = shared_links[shared_item]["links"][0] @@ -159,6 +164,7 @@ def forward( coord, atype, box: Optional[torch.Tensor] = None, + spin: Optional[torch.Tensor] = None, cur_lr: Optional[torch.Tensor] = None, label: Optional[torch.Tensor] = None, task_key: Optional[torch.Tensor] = None, diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 2125f9cdee..65a96418c9 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -35,6 +35,9 @@ from deepmd.pt.utils.dataset import ( DeepmdDataSetForLoader, ) +from deepmd.utils.data import ( + DataRequirementItem, +) from deepmd.utils.data_system import ( prob_sys_size_ext, process_sys_probs, @@ -147,6 +150,11 @@ def __getitem__(self, idx): batch["sid"] = idx return batch + def add_data_requirement(self, data_requirement: List[DataRequirementItem]): + """Add data requirement for each system in multiple systems.""" + for system in self.systems: + system.add_data_requirement(data_requirement) + _sentinel = object() QUEUESIZE = 32 @@ -248,7 +256,7 @@ def get_weighted_sampler(training_data, prob_style, sys_prob=False): probs = prob_sys_size_ext(style, len(training_data), training_data.index) else: probs = process_sys_probs(prob_style, training_data.index) - log.info("Generated weighted sampler with prob array: " + str(probs)) + log.debug("Generated weighted sampler with prob array: " + str(probs)) # training_data.total_batch is the size of one epoch, you can increase it to avoid too many rebuilding of iteraters len_sampler = training_data.total_batch * max(env.NUM_WORKERS, 1) with torch.device("cpu"): diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 4619b6417f..40a513acdf 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later + +from typing import ( + List, +) + from torch.utils.data import ( Dataset, ) from deepmd.utils.data import ( + DataRequirementItem, DeepmdData, ) @@ -27,9 +33,6 @@ def __init__( self._data_system = DeepmdData( sys_path=system, shuffle_test=shuffle, type_map=self._type_map ) - self._data_system.add("energy", 1, atomic=False, must=False, high_prec=True) - self._data_system.add("force", 3, atomic=True, must=False, high_prec=False) - self._data_system.add("virial", 9, atomic=False, must=False, high_prec=False) self.mixed_type = self._data_system.mixed_type self._ntypes = self._data_system.get_ntypes() self._natoms = self._data_system.get_natoms() @@ -43,3 +46,18 @@ def __getitem__(self, index): b_data = self._data_system.get_item_torch(index) b_data["natoms"] = self._natoms_vec return b_data + + def add_data_requirement(self, data_requirement: List[DataRequirementItem]): + """Add data requirement for this data system.""" + for data_item in data_requirement: + self._data_system.add( + data_item["key"], + data_item["ndof"], + atomic=data_item["atomic"], + must=data_item["must"], + high_prec=data_item["high_prec"], + type_sel=data_item["type_sel"], + repeat=data_item["repeat"], + default=data_item["default"], + dtype=data_item["dtype"], + ) diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py index 13749da151..c8fa1e5185 100644 --- a/deepmd/pt/utils/finetune.py +++ b/deepmd/pt/utils/finetune.py @@ -19,9 +19,7 @@ def change_finetune_model_params( - ckpt & finetune_model: origin model. - config: Read from json file. """ - if multi_task: - # TODO - log.error("finetune mode need modification for multitask mode!") + # TODO need support for multitask mode if finetune_model is not None: state_dict = torch.load(finetune_model, map_location=env.DEVICE) if "model" in state_dict: diff --git a/deepmd/pt/utils/multi_task.py b/deepmd/pt/utils/multi_task.py index f97a826b03..ae3933a101 100644 --- a/deepmd/pt/utils/multi_task.py +++ b/deepmd/pt/utils/multi_task.py @@ -4,17 +4,10 @@ ) from deepmd.pt.model.descriptor import ( - DescrptDPA1, - DescrptDPA2, - DescrptSeA, -) -from deepmd.pt.model.network.network import ( - TypeEmbedNet, + BaseDescriptor, ) from deepmd.pt.model.task import ( - EnergyFittingNet, - EnergyFittingNetDirect, - FittingNetAttenLcc, + BaseFitting, ) @@ -37,9 +30,68 @@ def preprocess_shared_params(model_config): - "shared_level": Shared level (int) of this item in this model. Lower for more params to share, 0 means to share all params in this item. This list are sorted by "shared_level". + For example, if one has `model_config` like this: + "model": { + "shared_dict": { + "my_type_map": ["foo", "bar"], + "my_des1": { + "type": "se_e2_a", + "neuron": [10, 20, 40] + }, + }, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": "my_des1", + "fitting_net": { + "neuron": [100, 100, 100] + } + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": "my_des1", + "fitting_net": { + "neuron": [100, 100, 100] + } + } + "model_3": { + "type_map": "my_type_map", + "descriptor": "my_des1:1", + "fitting_net": { + "neuron": [100, 100, 100] + } + } + } + } + The above config will init three model branches named `model_1` and `model_2` and `model_3`, + in which: + - `model_2` and `model_3` will have the same `type_map` as that in `model_1`. + - `model_2` will share all the parameters of `descriptor` with `model_1`, + while `model_3` will share part of parameters of `descriptor` with `model_1` + on human-defined share-level `1` (default is `0`, meaning share all the parameters). + - `model_1`, `model_2` and `model_3` have three different `fitting_net`s. + The returned `model_config` will automatically fulfill the input `model_config` as if there's no sharing, + and the `shared_links` will keep all the sharing information with looking: + { + 'my_des1': { + 'type': 'DescrptSeA', + 'links': [ + {'model_key': 'model_1', + 'shared_type': 'descriptor', + 'shared_level': 0}, + {'model_key': 'model_2', + 'shared_type': 'descriptor', + 'shared_level': 0}, + {'model_key': 'model_3', + 'shared_type': 'descriptor', + 'shared_level': 1} + ] + } + } + """ assert "model_dict" in model_config, "only multi-task model can use this method!" - supported_types = ["type_map", "type_embedding", "descriptor", "fitting_net"] + supported_types = ["type_map", "descriptor", "fitting_net"] shared_dict = model_config.get("shared_dict", {}) shared_links = {} type_map_keys = [] @@ -98,32 +150,9 @@ def replace_one_item(params_dict, key_type, key_in_dict, suffix="", index=None): def get_class_name(item_key, item_params): - if item_key == "type_embedding": - return TypeEmbedNet.__name__ - elif item_key == "descriptor": - item_type = item_params.get("type", "se_e2_a") - if item_type == "se_e2_a": - return DescrptSeA.__name__ - elif item_type in ["se_atten", "dpa1"]: - return DescrptDPA1.__name__ - elif item_type in ["dpa2"]: - return DescrptDPA2.__name__ - # todo add support for other combination - # elif item_type == "gaussian_lcc": - # return DescrptGaussianLcc.__name__ - # elif item_type == "hybrid": - # return DescrptHybrid.__name__ - else: - raise RuntimeError(f"Unknown descriptor type {item_type}") + if item_key == "descriptor": + return BaseDescriptor.get_class_by_type(item_params.get("type", "se_e2_a")) elif item_key == "fitting_net": - item_type = item_params.get("type", "ener") - if item_type == "ener": - return EnergyFittingNet.__name__ - elif item_type in ["direct_force", "direct_force_ener"]: - return EnergyFittingNetDirect.__name__ - elif item_type == "atten_vec_lcc": - return FittingNetAttenLcc.__name__ - else: - raise RuntimeError(f"Unknown fitting_net type {item_type}") + return BaseFitting.get_class_by_type(item_params.get("type", "ener")) else: raise RuntimeError(f"Unknown class_name type {item_key}") diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 4c769f019e..3b246a0ec2 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -3,6 +3,10 @@ import torch +from deepmd.pt.utils.utils import ( + dict_to_device, +) + log = logging.getLogger(__name__) @@ -18,19 +22,9 @@ def make_stat_input(datasets, dataloaders, nbatches): - a list of dicts, each of which contains data from a system """ lst = [] - keys = [ - "coord", - "force", - "energy", - "atype", - "box", - "natoms", - ] - if datasets[0].mixed_type: - keys.append("real_natoms_vec") log.info(f"Packing data for statistics from {len(datasets)} systems") for i in range(len(datasets)): - sys_stat = {key: [] for key in keys} + sys_stat = {} with torch.device("cpu"): iterator = iter(dataloaders[i]) for _ in range(nbatches): @@ -40,19 +34,19 @@ def make_stat_input(datasets, dataloaders, nbatches): iterator = iter(dataloaders[i]) stat_data = next(iterator) for dd in stat_data: - if dd in keys: + if stat_data[dd] is None: + sys_stat[dd] = None + elif isinstance(stat_data[dd], torch.Tensor): + if dd not in sys_stat: + sys_stat[dd] = [] sys_stat[dd].append(stat_data[dd]) - for key in keys: - if not isinstance(sys_stat[key][0], list): - if sys_stat[key][0] is None: - sys_stat[key] = None - else: - sys_stat[key] = torch.cat(sys_stat[key], dim=0) + else: + pass + for key in sys_stat: + if sys_stat[key] is None or sys_stat[key][0] is None: + sys_stat[key] = None else: - sys_stat_list = [] - for ii, _ in enumerate(sys_stat[key][0]): - tmp_stat = [x[ii] for x in sys_stat[key]] - sys_stat_list.append(torch.cat(tmp_stat, dim=0)) - sys_stat[key] = sys_stat_list + sys_stat[key] = torch.cat(sys_stat[key], dim=0) + dict_to_device(sys_stat) lst.append(sys_stat) return lst diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 6e0c47881f..03e39e1f21 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -490,6 +490,8 @@ def reformat_data_torch(self, data): if self.data_dict[kk]["atomic"]: data[kk] = data[kk].reshape(-1, self.data_dict[kk]["ndof"]) data["atype"] = data["type"] + if not self.pbc: + data["box"] = None return data def _load_set(self, set_name: DPPath): @@ -664,3 +666,73 @@ def _check_pbc(self, sys_path: DPPath): def _check_mode(self, set_path: DPPath): return (set_path / "real_atom_types.npy").is_file() + + +class DataRequirementItem: + """A class to store the data requirement for data systems. + + Parameters + ---------- + key + The key of the item. The corresponding data is stored in `sys_path/set.*/key.npy` + ndof + The number of dof + atomic + The item is an atomic property. + If False, the size of the data should be nframes x ndof + If True, the size of data should be nframes x natoms x ndof + must + The data file `sys_path/set.*/key.npy` must exist. + If must is False and the data file does not exist, the `data_dict[find_key]` is set to 0.0 + high_prec + Load the data and store in float64, otherwise in float32 + type_sel + Select certain type of atoms + repeat + The data will be repeated `repeat` times. + default : float, default=0. + default value of data + dtype : np.dtype, optional + the dtype of data, overwrites `high_prec` if provided + """ + + def __init__( + self, + key: str, + ndof: int, + atomic: bool = False, + must: bool = False, + high_prec: bool = False, + type_sel: Optional[List[int]] = None, + repeat: int = 1, + default: float = 0.0, + dtype: Optional[np.dtype] = None, + ) -> None: + self.key = key + self.ndof = ndof + self.atomic = atomic + self.must = must + self.high_prec = high_prec + self.type_sel = type_sel + self.repeat = repeat + self.default = default + self.dtype = dtype + self.dict = self.to_dict() + + def to_dict(self) -> dict: + return { + "key": self.key, + "ndof": self.ndof, + "atomic": self.atomic, + "must": self.must, + "high_prec": self.high_prec, + "type_sel": self.type_sel, + "repeat": self.repeat, + "default": self.default, + "dtype": self.dtype, + } + + def __getitem__(self, key: str): + if key not in self.dict: + raise KeyError(key) + return self.dict[key] diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py index 2fa497b9b6..217c46844b 100644 --- a/deepmd/utils/env_mat_stat.py +++ b/deepmd/utils/env_mat_stat.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import logging from abc import ( ABC, abstractmethod, @@ -19,6 +20,8 @@ DPPath, ) +log = logging.getLogger(__name__) + class StatItem: """A class to store the statistics of the environment matrix. @@ -170,10 +173,12 @@ def load_or_compute_stats( """ if path is not None and path.is_dir(): self.load_stats(path) + log.info(f"Load stats from {path}.") else: self.compute_stats(data) if path is not None: self.save_stats(path) + log.info(f"Save stats to {path}.") def get_avg(self, default: float = 0) -> Dict[str, float]: """Get the average of the environment matrix. diff --git a/source/tests/pt/model/test_descriptor.py b/source/tests/pt/model/test_descriptor.py index ffad27201a..7d21d1c13d 100644 --- a/source/tests/pt/model/test_descriptor.py +++ b/source/tests/pt/model/test_descriptor.py @@ -38,6 +38,9 @@ op_module, ) +from ..test_stat import ( + energy_data_requirement, +) from .test_embedding_net import ( get_single_batch, ) @@ -114,6 +117,7 @@ def setUp(self): self.systems[0], model_config["type_map"], ) + ds.add_data_requirement(energy_data_requirement) self.np_batch, self.pt_batch = get_single_batch(ds) self.sec = np.cumsum(self.sel) self.ntypes = len(self.sel) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index fcdd408726..fa4be9171c 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -79,7 +79,6 @@ def test_consistency( [0, 4], ): ft0 = DipoleFittingNet( - "foo", self.nt, self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), @@ -115,12 +114,12 @@ def test_consistency( ) ret2 = ft2(rd0, atype, gr, fparam=ifp, aparam=iap) np.testing.assert_allclose( - to_numpy_array(ret0["foo"]), - ret1["foo"], + to_numpy_array(ret0["dipole"]), + ret1["dipole"], ) np.testing.assert_allclose( - to_numpy_array(ret0["foo"]), - to_numpy_array(ret2["foo"]), + to_numpy_array(ret0["dipole"]), + to_numpy_array(ret2["dipole"]), ) def test_jit( @@ -132,7 +131,6 @@ def test_jit( [0, 4], ): ft0 = DipoleFittingNet( - "foo", self.nt, self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), @@ -168,7 +166,6 @@ def test_rot(self): [0, 4], ): ft0 = DipoleFittingNet( - "foo", 3, # ntype self.dd0.dim_out, # dim_descrpt embedding_width=self.dd0.get_dim_emb(), @@ -209,7 +206,7 @@ def test_rot(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap) - res.append(ret0["foo"]) + res.append(ret0["dipole"]) np.testing.assert_allclose( to_numpy_array(res[1]), to_numpy_array(torch.matmul(res[0], rmat)) @@ -218,7 +215,6 @@ def test_rot(self): def test_permu(self): coord = torch.matmul(self.coord, self.cell) ft0 = DipoleFittingNet( - "foo", 3, # ntype self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), @@ -245,7 +241,7 @@ def test_permu(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) - res.append(ret0["foo"]) + res.append(ret0["dipole"]) np.testing.assert_allclose( to_numpy_array(res[0][:, idx_perm]), to_numpy_array(res[1]) @@ -260,7 +256,6 @@ def test_trans(self): self.cell, ) ft0 = DipoleFittingNet( - "foo", 3, # ntype self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), @@ -286,7 +281,7 @@ def test_trans(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) - res.append(ret0["foo"]) + res.append(ret0["dipole"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) @@ -305,7 +300,6 @@ def setUp(self): self.atype = torch.IntTensor([0, 0, 0, 1, 1], device="cpu").to(env.DEVICE) self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) self.ft0 = DipoleFittingNet( - "dipole", self.nt, self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), diff --git a/source/tests/pt/model/test_embedding_net.py b/source/tests/pt/model/test_embedding_net.py index 87e8a97444..a1895718dd 100644 --- a/source/tests/pt/model/test_embedding_net.py +++ b/source/tests/pt/model/test_embedding_net.py @@ -39,6 +39,10 @@ ) from deepmd.tf.descriptor import DescrptSeA as DescrptSeA_tf +from ..test_stat import ( + energy_data_requirement, +) + CUR_DIR = os.path.dirname(__file__) @@ -128,6 +132,7 @@ def setUp(self): self.systems[0], model_config["type_map"], ) + ds.add_data_requirement(energy_data_requirement) self.filter_neuron = model_config["descriptor"]["neuron"] self.axis_neuron = model_config["descriptor"]["axis_neuron"] self.np_batch, self.torch_batch = get_single_batch(ds) diff --git a/source/tests/pt/model/test_model.py b/source/tests/pt/model/test_model.py index d8c7de39c3..69ec88f5d7 100644 --- a/source/tests/pt/model/test_model.py +++ b/source/tests/pt/model/test_model.py @@ -51,6 +51,10 @@ LearningRateExp, ) +from ..test_stat import ( + energy_data_requirement, +) + VariableState = collections.namedtuple("VariableState", ["value", "gradient"]) @@ -281,6 +285,7 @@ def test_consistency(self): "type_map": self.type_map, }, ) + my_ds.add_data_requirement(energy_data_requirement) my_model = get_model( model_params={ "descriptor": { diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index f76a9e28ac..b1a5e3f730 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -67,7 +67,6 @@ def test_consistency( [None, self.scale], ): ft0 = PolarFittingNet( - "foo", self.nt, self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), @@ -113,16 +112,16 @@ def test_consistency( aparam=to_numpy_array(iap), ) np.testing.assert_allclose( - to_numpy_array(ret0["foo"]), - ret1["foo"], + to_numpy_array(ret0["polar"]), + ret1["polar"], ) np.testing.assert_allclose( - to_numpy_array(ret0["foo"]), - to_numpy_array(ret2["foo"]), + to_numpy_array(ret0["polar"]), + to_numpy_array(ret2["polar"]), ) np.testing.assert_allclose( - to_numpy_array(ret0["foo"]), - ret3["foo"], + to_numpy_array(ret0["polar"]), + ret3["polar"], ) def test_jit( @@ -135,7 +134,6 @@ def test_jit( [True, False], ): ft0 = PolarFittingNet( - "foo", self.nt, self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), @@ -177,7 +175,6 @@ def test_rot(self): [None, self.scale], ): ft0 = PolarFittingNet( - "foo", self.nt, self.dd0.dim_out, # dim_descrpt embedding_width=self.dd0.get_dim_emb(), @@ -220,7 +217,7 @@ def test_rot(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap) - res.append(ret0["foo"]) + res.append(ret0["polar"]) np.testing.assert_allclose( to_numpy_array(res[1]), to_numpy_array( @@ -235,7 +232,6 @@ def test_permu(self): coord = torch.matmul(self.coord, self.cell) for fit_diag, scale in itertools.product([True, False], [None, self.scale]): ft0 = PolarFittingNet( - "foo", self.nt, self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), @@ -264,7 +260,7 @@ def test_permu(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=None, aparam=None) - res.append(ret0["foo"]) + res.append(ret0["polar"]) np.testing.assert_allclose( to_numpy_array(res[0][:, idx_perm]), @@ -281,7 +277,6 @@ def test_trans(self): ) for fit_diag, scale in itertools.product([True, False], [None, self.scale]): ft0 = PolarFittingNet( - "foo", self.nt, self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), @@ -309,7 +304,7 @@ def test_trans(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) - res.append(ret0["foo"]) + res.append(ret0["polar"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) @@ -328,7 +323,6 @@ def setUp(self): self.atype = torch.IntTensor([0, 0, 0, 1, 1], device="cpu") self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) self.ft0 = PolarFittingNet( - "polar", self.nt, self.dd0.dim_out, embedding_width=self.dd0.get_dim_emb(), diff --git a/source/tests/pt/model/water/multitask.json b/source/tests/pt/model/water/multitask.json new file mode 100644 index 0000000000..6baddd672b --- /dev/null +++ b/source/tests/pt/model/water/multitask.json @@ -0,0 +1,139 @@ +{ + "model": { + "shared_dict": { + "my_type_map": [ + "O", + "H", + "B" + ], + "my_descriptor": { + "type": "se_e2_a", + "sel": [ + 46, + 92 + ], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "seed": 1, + "_comment": " that's all" + }, + "_comment": "that's all" + }, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + } + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + } + } + } + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.0002, + "decay_rate": 0.98, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss_dict": { + "_comment": " that's all", + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0 + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0 + } + }, + "training": { + "model_prob": { + "model_1": 0.5, + "model_2": 0.5 + }, + "data_dict": { + "model_1": { + "stat_file": "./stat_files/model_1", + "training_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "_comment": "that's all" + } + }, + "model_2": { + "stat_file": "./stat_files/model_2", + "training_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "_comment": "that's all" + } + } + }, + "numb_steps": 100000, + "warmup_steps": 0, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 100, + "_comment": "that's all" + } +} diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py index e117c7f05a..484d62a3ad 100644 --- a/source/tests/pt/test_loss.py +++ b/source/tests/pt/test_loss.py @@ -28,6 +28,9 @@ from .model.test_embedding_net import ( get_single_batch, ) +from .test_stat import ( + energy_data_requirement, +) CUR_DIR = os.path.dirname(__file__) @@ -47,6 +50,7 @@ def get_batch(): if isinstance(systems, str): systems = expand_sys_str(systems) dataset = DeepmdDataSetForLoader(systems[0], model_config["type_map"]) + dataset.add_data_requirement(energy_data_requirement) np_batch, pt_batch = get_single_batch(dataset) return np_batch, pt_batch diff --git a/source/tests/pt/test_multitask.py b/source/tests/pt/test_multitask.py new file mode 100644 index 0000000000..3c0240dbdc --- /dev/null +++ b/source/tests/pt/test_multitask.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import shutil +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) + +import torch + +from deepmd.pt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt.utils.multi_task import ( + preprocess_shared_params, +) + +from .model.test_permutation import ( + model_dpa1, + model_dpa2, + model_se_e2_a, +) + +multitask_template_json = str(Path(__file__).parent / "water/multitask.json") +with open(multitask_template_json) as f: + multitask_template = json.load(f) + + +class MultiTaskTrainTest: + def test_multitask_train(self): + trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links) + trainer.run() + # check model keys + self.assertEqual(len(trainer.wrapper.model), 2) + self.assertIn("model_1", trainer.wrapper.model) + self.assertIn("model_2", trainer.wrapper.model) + + # check shared parameters + multi_state_dict = trainer.wrapper.model.state_dict() + for state_key in multi_state_dict: + if "model_1" in state_key: + self.assertIn(state_key.replace("model_1", "model_2"), multi_state_dict) + if "model_2" in state_key: + self.assertIn(state_key.replace("model_2", "model_1"), multi_state_dict) + if "model_1.descriptor" in state_key: + torch.testing.assert_allclose( + multi_state_dict[state_key], + multi_state_dict[state_key.replace("model_1", "model_2")], + ) + self.tearDown() + + def tearDown(self): + for f in os.listdir("."): + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out"]: + os.remove(f) + if f in [self.stat_files]: + shutil.rmtree(f) + + +class TestMultiTaskSeA(unittest.TestCase, MultiTaskTrainTest): + def setUp(self): + multitask_se_e2_a = deepcopy(multitask_template) + multitask_se_e2_a["model"]["shared_dict"]["my_descriptor"] = model_se_e2_a[ + "descriptor" + ] + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.stat_files = "se_e2_a" + os.makedirs(self.stat_files, exist_ok=True) + self.config = multitask_se_e2_a + self.config["training"]["data_dict"]["model_1"]["training_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_1"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_1"][ + "stat_file" + ] = f"{self.stat_files}/model_1" + self.config["training"]["data_dict"]["model_2"]["training_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_2"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_2"][ + "stat_file" + ] = f"{self.stat_files}/model_2" + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.config["model"], self.shared_links = preprocess_shared_params( + self.config["model"] + ) + + def tearDown(self) -> None: + MultiTaskTrainTest.tearDown(self) + + +class TestMultiTaskDPA1(unittest.TestCase, MultiTaskTrainTest): + def setUp(self): + multitask_DPA1 = deepcopy(multitask_template) + multitask_DPA1["model"]["shared_dict"]["my_descriptor"] = model_dpa1[ + "descriptor" + ] + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.stat_files = "DPA1" + os.makedirs(self.stat_files, exist_ok=True) + self.config = multitask_DPA1 + self.config["training"]["data_dict"]["model_1"]["training_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_1"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_1"][ + "stat_file" + ] = f"{self.stat_files}/model_1" + self.config["training"]["data_dict"]["model_2"]["training_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_2"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_2"][ + "stat_file" + ] = f"{self.stat_files}/model_2" + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.config["model"], self.shared_links = preprocess_shared_params( + self.config["model"] + ) + + def tearDown(self) -> None: + MultiTaskTrainTest.tearDown(self) + + +class TestMultiTaskDPA2(unittest.TestCase, MultiTaskTrainTest): + def setUp(self): + multitask_DPA2 = deepcopy(multitask_template) + multitask_DPA2["model"]["shared_dict"]["my_descriptor"] = model_dpa2[ + "descriptor" + ] + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.stat_files = "DPA2" + os.makedirs(self.stat_files, exist_ok=True) + self.config = multitask_DPA2 + self.config["training"]["data_dict"]["model_1"]["training_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_1"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_1"][ + "stat_file" + ] = f"{self.stat_files}/model_1" + self.config["training"]["data_dict"]["model_2"]["training_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_2"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_2"][ + "stat_file" + ] = f"{self.stat_files}/model_2" + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.config["model"], self.shared_links = preprocess_shared_params( + self.config["model"] + ) + + def tearDown(self) -> None: + MultiTaskTrainTest.tearDown(self) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_stat.py b/source/tests/pt/test_stat.py index 98d4e59d95..3a09f82baf 100644 --- a/source/tests/pt/test_stat.py +++ b/source/tests/pt/test_stat.py @@ -44,9 +44,51 @@ from deepmd.tf.utils.data_system import ( DeepmdDataSystem, ) +from deepmd.utils.data import ( + DataRequirementItem, +) CUR_DIR = os.path.dirname(__file__) +energy_data_requirement = [ + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + must=False, + high_prec=True, + ), + DataRequirementItem( + "force", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "virial", + ndof=9, + atomic=False, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atom_ener", + ndof=1, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atom_pref", + ndof=1, + atomic=True, + must=False, + high_prec=False, + repeat=3, + ), +] + def compare(ut, base, given): if isinstance(base, list): @@ -111,6 +153,7 @@ def setUp(self): self.filter_neuron = model_config["descriptor"]["neuron"] self.axis_neuron = model_config["descriptor"]["axis_neuron"] self.n_neuron = model_config["fitting_net"]["neuron"] + self.my_dataset.add_data_requirement(energy_data_requirement) self.my_sampled = my_make( self.my_dataset.systems, self.my_dataset.dataloaders, self.data_stat_nbatch @@ -181,8 +224,6 @@ def test_descriptor(self): for sys in sampled: for key in [ "coord", - "force", - "energy", "atype", "natoms", "box", diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index f86691cde6..4e73fc4f8a 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -79,15 +79,6 @@ def setUp(self): self.config["training"]["training_data"]["systems"] = data_file self.config["training"]["validation_data"]["systems"] = data_file self.config["model"] = deepcopy(model_dpa2) - # self.config["model"]["descriptor"]["rcut"] = self.config["model"]["descriptor"][ - # "repinit_rcut" - # ] - # self.config["model"]["descriptor"]["rcut_smth"] = self.config["model"][ - # "descriptor" - # ]["repinit_rcut_smth"] - # self.config["model"]["descriptor"]["sel"] = self.config["model"]["descriptor"][ - # "repinit_nsel" - # ] self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1