Skip to content

Commit

Permalink
Chore: refactor get standard model (#4205)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **Refactor**
- Simplified model component creation by introducing a new function for
better code clarity and reusability.
- Updated model-building functions to utilize the new component creation
logic, enhancing maintainability.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
anyangml and pre-commit-ci[bot] authored Oct 14, 2024
1 parent 8279cca commit a1f8672
Showing 1 changed file with 35 additions and 61 deletions.
96 changes: 35 additions & 61 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,29 @@
)


def _get_standard_model_components(model_params, ntypes):
# descriptor
model_params["descriptor"]["ntypes"] = ntypes
model_params["descriptor"]["type_map"] = copy.deepcopy(model_params["type_map"])
descriptor = BaseDescriptor(**model_params["descriptor"])
# fitting
fitting_net = model_params.get("fitting_net", {})
fitting_net["type"] = fitting_net.get("type", "ener")
fitting_net["ntypes"] = descriptor.get_ntypes()
fitting_net["type_map"] = copy.deepcopy(model_params["type_map"])
fitting_net["mixed_types"] = descriptor.mixed_types()
if fitting_net["type"] in ["dipole", "polar"]:
fitting_net["embedding_width"] = descriptor.get_dim_emb()
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
grad_force = "direct" not in fitting_net["type"]
if not grad_force:
fitting_net["out_dim"] = descriptor.get_dim_emb()
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
fitting = BaseFitting(**fitting_net)
return descriptor, fitting, fitting_net["type"]


def get_spin_model(model_params):
model_params = copy.deepcopy(model_params)
if not model_params["spin"]["use_spin"] or isinstance(
Expand Down Expand Up @@ -117,25 +140,9 @@ def get_linear_model(model_params):
if "descriptor" in sub_model_params:
# descriptor
sub_model_params["descriptor"]["ntypes"] = ntypes
sub_model_params["descriptor"]["type_map"] = copy.deepcopy(
model_params["type_map"]
descriptor, fitting, _ = _get_standard_model_components(
sub_model_params, ntypes
)
descriptor = BaseDescriptor(**sub_model_params["descriptor"])
# fitting
fitting_net = sub_model_params.get("fitting_net", {})
fitting_net["type"] = fitting_net.get("type", "ener")
fitting_net["ntypes"] = descriptor.get_ntypes()
fitting_net["type_map"] = copy.deepcopy(model_params["type_map"])
fitting_net["mixed_types"] = descriptor.mixed_types()
if fitting_net["type"] in ["dipole", "polar"]:
fitting_net["embedding_width"] = descriptor.get_dim_emb()
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
grad_force = "direct" not in fitting_net["type"]
if not grad_force:
fitting_net["out_dim"] = descriptor.get_dim_emb()
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
fitting = BaseFitting(**fitting_net)
list_of_models.append(
DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"])
)
Expand Down Expand Up @@ -167,24 +174,7 @@ def get_linear_model(model_params):
def get_zbl_model(model_params):
model_params = copy.deepcopy(model_params)
ntypes = len(model_params["type_map"])
# descriptor
model_params["descriptor"]["ntypes"] = ntypes
model_params["descriptor"]["type_map"] = copy.deepcopy(model_params["type_map"])
descriptor = BaseDescriptor(**model_params["descriptor"])
# fitting
fitting_net = model_params.get("fitting_net", None)
fitting_net["type"] = fitting_net.get("type", "ener")
fitting_net["ntypes"] = descriptor.get_ntypes()
fitting_net["type_map"] = copy.deepcopy(model_params["type_map"])
fitting_net["mixed_types"] = descriptor.mixed_types()
fitting_net["embedding_width"] = descriptor.get_dim_out()
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
grad_force = "direct" not in fitting_net["type"]
if not grad_force:
fitting_net["out_dim"] = descriptor.get_dim_emb()
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
fitting = BaseFitting(**fitting_net)
descriptor, fitting, _ = _get_standard_model_components(model_params, ntypes)
dp_model = DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"])
# pairtab
filepath = model_params["use_srtab"]
Expand Down Expand Up @@ -246,44 +236,28 @@ def get_standard_model(model_params):
model_params_old = model_params
model_params = copy.deepcopy(model_params)
ntypes = len(model_params["type_map"])
# descriptor
model_params["descriptor"]["ntypes"] = ntypes
model_params["descriptor"]["type_map"] = copy.deepcopy(model_params["type_map"])
descriptor = BaseDescriptor(**model_params["descriptor"])
# fitting
fitting_net = model_params.get("fitting_net", {})
fitting_net["type"] = fitting_net.get("type", "ener")
fitting_net["ntypes"] = descriptor.get_ntypes()
fitting_net["type_map"] = copy.deepcopy(model_params["type_map"])
fitting_net["mixed_types"] = descriptor.mixed_types()
if fitting_net["type"] in ["dipole", "polar"]:
fitting_net["embedding_width"] = descriptor.get_dim_emb()
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
grad_force = "direct" not in fitting_net["type"]
if not grad_force:
fitting_net["out_dim"] = descriptor.get_dim_emb()
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
fitting = BaseFitting(**fitting_net)
descriptor, fitting, fitting_net_type = _get_standard_model_components(
model_params, ntypes
)
atom_exclude_types = model_params.get("atom_exclude_types", [])
pair_exclude_types = model_params.get("pair_exclude_types", [])
preset_out_bias = model_params.get("preset_out_bias")
preset_out_bias = _convert_preset_out_bias_to_array(
preset_out_bias, model_params["type_map"]
)

if fitting_net["type"] == "dipole":
if fitting_net_type == "dipole":
modelcls = DipoleModel
elif fitting_net["type"] == "polar":
elif fitting_net_type == "polar":
modelcls = PolarModel
elif fitting_net["type"] == "dos":
elif fitting_net_type == "dos":
modelcls = DOSModel
elif fitting_net["type"] in ["ener", "direct_force_ener"]:
elif fitting_net_type in ["ener", "direct_force_ener"]:
modelcls = EnergyModel
elif fitting_net["type"] == "property":
elif fitting_net_type == "property":
modelcls = PropertyModel
else:
raise RuntimeError(f"Unknown fitting type: {fitting_net['type']}")
raise RuntimeError(f"Unknown fitting type: {fitting_net_type}")

model = modelcls(
descriptor=descriptor,
Expand Down

0 comments on commit a1f8672

Please sign in to comment.