From a1f867217e3d06a9f5921cd5b2b76e42649b0882 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 14 Oct 2024 15:44:02 +0800 Subject: [PATCH] Chore: refactor get standard model (#4205) ## Summary by CodeRabbit - **Refactor** - Simplified model component creation by introducing a new function for better code clarity and reusability. - Updated model-building functions to utilize the new component creation logic, enhancing maintainability. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/pt/model/model/__init__.py | 96 +++++++++++-------------------- 1 file changed, 35 insertions(+), 61 deletions(-) diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 26aefa6201..613baf440e 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -72,6 +72,29 @@ ) +def _get_standard_model_components(model_params, ntypes): + # descriptor + model_params["descriptor"]["ntypes"] = ntypes + model_params["descriptor"]["type_map"] = copy.deepcopy(model_params["type_map"]) + descriptor = BaseDescriptor(**model_params["descriptor"]) + # fitting + fitting_net = model_params.get("fitting_net", {}) + fitting_net["type"] = fitting_net.get("type", "ener") + fitting_net["ntypes"] = descriptor.get_ntypes() + fitting_net["type_map"] = copy.deepcopy(model_params["type_map"]) + fitting_net["mixed_types"] = descriptor.mixed_types() + if fitting_net["type"] in ["dipole", "polar"]: + fitting_net["embedding_width"] = descriptor.get_dim_emb() + fitting_net["dim_descrpt"] = descriptor.get_dim_out() + grad_force = "direct" not in fitting_net["type"] + if not grad_force: + fitting_net["out_dim"] = descriptor.get_dim_emb() + if "ener" in fitting_net["type"]: + fitting_net["return_energy"] = True + fitting = BaseFitting(**fitting_net) + return descriptor, fitting, fitting_net["type"] + + def get_spin_model(model_params): model_params = copy.deepcopy(model_params) if not model_params["spin"]["use_spin"] or isinstance( @@ -117,25 +140,9 @@ def get_linear_model(model_params): if "descriptor" in sub_model_params: # descriptor sub_model_params["descriptor"]["ntypes"] = ntypes - sub_model_params["descriptor"]["type_map"] = copy.deepcopy( - model_params["type_map"] + descriptor, fitting, _ = _get_standard_model_components( + sub_model_params, ntypes ) - descriptor = BaseDescriptor(**sub_model_params["descriptor"]) - # fitting - fitting_net = sub_model_params.get("fitting_net", {}) - fitting_net["type"] = fitting_net.get("type", "ener") - fitting_net["ntypes"] = descriptor.get_ntypes() - fitting_net["type_map"] = copy.deepcopy(model_params["type_map"]) - fitting_net["mixed_types"] = descriptor.mixed_types() - if fitting_net["type"] in ["dipole", "polar"]: - fitting_net["embedding_width"] = descriptor.get_dim_emb() - fitting_net["dim_descrpt"] = descriptor.get_dim_out() - grad_force = "direct" not in fitting_net["type"] - if not grad_force: - fitting_net["out_dim"] = descriptor.get_dim_emb() - if "ener" in fitting_net["type"]: - fitting_net["return_energy"] = True - fitting = BaseFitting(**fitting_net) list_of_models.append( DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"]) ) @@ -167,24 +174,7 @@ def get_linear_model(model_params): def get_zbl_model(model_params): model_params = copy.deepcopy(model_params) ntypes = len(model_params["type_map"]) - # descriptor - model_params["descriptor"]["ntypes"] = ntypes - model_params["descriptor"]["type_map"] = copy.deepcopy(model_params["type_map"]) - descriptor = BaseDescriptor(**model_params["descriptor"]) - # fitting - fitting_net = model_params.get("fitting_net", None) - fitting_net["type"] = fitting_net.get("type", "ener") - fitting_net["ntypes"] = descriptor.get_ntypes() - fitting_net["type_map"] = copy.deepcopy(model_params["type_map"]) - fitting_net["mixed_types"] = descriptor.mixed_types() - fitting_net["embedding_width"] = descriptor.get_dim_out() - fitting_net["dim_descrpt"] = descriptor.get_dim_out() - grad_force = "direct" not in fitting_net["type"] - if not grad_force: - fitting_net["out_dim"] = descriptor.get_dim_emb() - if "ener" in fitting_net["type"]: - fitting_net["return_energy"] = True - fitting = BaseFitting(**fitting_net) + descriptor, fitting, _ = _get_standard_model_components(model_params, ntypes) dp_model = DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"]) # pairtab filepath = model_params["use_srtab"] @@ -246,25 +236,9 @@ def get_standard_model(model_params): model_params_old = model_params model_params = copy.deepcopy(model_params) ntypes = len(model_params["type_map"]) - # descriptor - model_params["descriptor"]["ntypes"] = ntypes - model_params["descriptor"]["type_map"] = copy.deepcopy(model_params["type_map"]) - descriptor = BaseDescriptor(**model_params["descriptor"]) - # fitting - fitting_net = model_params.get("fitting_net", {}) - fitting_net["type"] = fitting_net.get("type", "ener") - fitting_net["ntypes"] = descriptor.get_ntypes() - fitting_net["type_map"] = copy.deepcopy(model_params["type_map"]) - fitting_net["mixed_types"] = descriptor.mixed_types() - if fitting_net["type"] in ["dipole", "polar"]: - fitting_net["embedding_width"] = descriptor.get_dim_emb() - fitting_net["dim_descrpt"] = descriptor.get_dim_out() - grad_force = "direct" not in fitting_net["type"] - if not grad_force: - fitting_net["out_dim"] = descriptor.get_dim_emb() - if "ener" in fitting_net["type"]: - fitting_net["return_energy"] = True - fitting = BaseFitting(**fitting_net) + descriptor, fitting, fitting_net_type = _get_standard_model_components( + model_params, ntypes + ) atom_exclude_types = model_params.get("atom_exclude_types", []) pair_exclude_types = model_params.get("pair_exclude_types", []) preset_out_bias = model_params.get("preset_out_bias") @@ -272,18 +246,18 @@ def get_standard_model(model_params): preset_out_bias, model_params["type_map"] ) - if fitting_net["type"] == "dipole": + if fitting_net_type == "dipole": modelcls = DipoleModel - elif fitting_net["type"] == "polar": + elif fitting_net_type == "polar": modelcls = PolarModel - elif fitting_net["type"] == "dos": + elif fitting_net_type == "dos": modelcls = DOSModel - elif fitting_net["type"] in ["ener", "direct_force_ener"]: + elif fitting_net_type in ["ener", "direct_force_ener"]: modelcls = EnergyModel - elif fitting_net["type"] == "property": + elif fitting_net_type == "property": modelcls = PropertyModel else: - raise RuntimeError(f"Unknown fitting type: {fitting_net['type']}") + raise RuntimeError(f"Unknown fitting type: {fitting_net_type}") model = modelcls( descriptor=descriptor,