diff --git a/deepmd/dpmodel/__init__.py b/deepmd/dpmodel/__init__.py index 6a7bdb3585..6f83f849a3 100644 --- a/deepmd/dpmodel/__init__.py +++ b/deepmd/dpmodel/__init__.py @@ -5,7 +5,7 @@ NativeOP, ) from .model import ( - DPModel, + DPModelCommon, ) from .output_def import ( FittingOutputDef, @@ -19,7 +19,7 @@ ) __all__ = [ - "DPModel", + "DPModelCommon", "PRECISION_DICT", "DEFAULT_PRECISION", "NativeOP", diff --git a/deepmd/dpmodel/model/__init__.py b/deepmd/dpmodel/model/__init__.py index c1ff15ab0d..7cd68dea60 100644 --- a/deepmd/dpmodel/model/__init__.py +++ b/deepmd/dpmodel/model/__init__.py @@ -13,7 +13,7 @@ """ from .dp_model import ( - DPModel, + DPModelCommon, ) from .make_model import ( make_model, @@ -23,7 +23,7 @@ ) __all__ = [ - "DPModel", + "DPModelCommon", "SpinModel", "make_model", ] diff --git a/deepmd/dpmodel/model/base_model.py b/deepmd/dpmodel/model/base_model.py index 5169d1b5fe..de69afcf6c 100644 --- a/deepmd/dpmodel/model/base_model.py +++ b/deepmd/dpmodel/model/base_model.py @@ -35,7 +35,11 @@ class BaseBaseModel(ABC, PluginVariant, make_plugin_registry("model")): def __new__(cls, *args, **kwargs): if inspect.isabstract(cls): - cls = cls.get_class_by_type(kwargs.get("type", "standard")) + # getting model type based on fitting type + model_type = kwargs.get("type", "standard") + if model_type == "standard": + model_type = kwargs.get("fitting", {}).get("type", "ener") + cls = cls.get_class_by_type(model_type) return super().__new__(cls) @abstractmethod @@ -118,7 +122,10 @@ def deserialize(cls, data: dict) -> "BaseBaseModel": The deserialized model """ if inspect.isabstract(cls): - return cls.get_class_by_type(data["type"]).deserialize(data) + model_type = data.get("type", "standard") + if model_type == "standard": + model_type = data.get("fitting", {}).get("type", "ener") + return cls.get_class_by_type(model_type).deserialize(data) raise NotImplementedError("Not implemented in class %s" % cls.__name__) model_def_script: str @@ -151,7 +158,11 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict): local_jdata : dict The local data refer to the current class """ - cls = cls.get_class_by_type(local_jdata.get("type", "standard")) + # getting model type based on fitting type + model_type = local_jdata.get("type", "standard") + if model_type == "standard": + model_type = local_jdata.get("fitting", {}).get("type", "ener") + cls = cls.get_class_by_type(model_type) return cls.update_sel(global_jdata, local_jdata) return BaseBaseModel diff --git a/deepmd/dpmodel/model/dp_model.py b/deepmd/dpmodel/model/dp_model.py index 8d84c435b4..37cb426ab7 100644 --- a/deepmd/dpmodel/model/dp_model.py +++ b/deepmd/dpmodel/model/dp_model.py @@ -1,24 +1,13 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from deepmd.dpmodel.atomic_model import ( - DPAtomicModel, -) from deepmd.dpmodel.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.dpmodel.model.base_model import ( - BaseModel, -) - -from .make_model import ( - make_model, -) # use "class" to resolve "Variable not allowed in type expression" -@BaseModel.register("standard") -class DPModel(make_model(DPAtomicModel)): +class DPModelCommon: @classmethod def update_sel(cls, global_jdata: dict, local_jdata: dict): """Update the selection and perform neighbor statistics. diff --git a/deepmd/dpmodel/model/ener_model.py b/deepmd/dpmodel/model/ener_model.py new file mode 100644 index 0000000000..5f21681830 --- /dev/null +++ b/deepmd/dpmodel/model/ener_model.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) + +from .dp_model import ( + DPModelCommon, +) +from .make_model import ( + make_model, +) + +DPEnergyModel_ = make_model(DPAtomicModel) + + +@BaseModel.register("ener") +class EnergyModel(DPModelCommon, DPEnergyModel_): + def __init__( + self, + *args, + **kwargs, + ): + DPModelCommon.__init__(self) + DPEnergyModel_.__init__(self, *args, **kwargs) diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 3fdf5b802b..0df6e94f05 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -5,8 +5,8 @@ from deepmd.dpmodel.fitting.ener_fitting import ( EnergyFittingNet, ) -from deepmd.dpmodel.model.dp_model import ( - DPModel, +from deepmd.dpmodel.model.ener_model import ( + EnergyModel, ) from deepmd.dpmodel.model.spin_model import ( SpinModel, @@ -16,8 +16,8 @@ ) -def get_standard_model(data: dict) -> DPModel: - """Get a standard DPModel from a dictionary. +def get_standard_model(data: dict) -> EnergyModel: + """Get a EnergyModel from a dictionary. Parameters ---------- @@ -41,7 +41,7 @@ def get_standard_model(data: dict) -> DPModel: ) else: raise ValueError(f"Unknown fitting type {fitting_type}") - return DPModel( + return EnergyModel( descriptor=descriptor, fitting=fitting, type_map=data["type_map"], diff --git a/deepmd/dpmodel/model/spin_model.py b/deepmd/dpmodel/model/spin_model.py index 5b31b64fdf..90e2bb3fb4 100644 --- a/deepmd/dpmodel/model/spin_model.py +++ b/deepmd/dpmodel/model/spin_model.py @@ -7,8 +7,11 @@ import numpy as np -from deepmd.dpmodel.model.dp_model import ( - DPModel, +from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.dpmodel.model.make_model import ( + make_model, ) from deepmd.utils.spin import ( Spin, @@ -259,7 +262,9 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "SpinModel": - backbone_model_obj = DPModel.deserialize(data["backbone_model"]) + backbone_model_obj = make_model(DPAtomicModel).deserialize( + data["backbone_model"] + ) spin = Spin.deserialize(data["spin"]) return cls( backbone_model=backbone_model_obj, diff --git a/deepmd/pt/model/atomic_model/__init__.py b/deepmd/pt/model/atomic_model/__init__.py index a747f28556..3e94449057 100644 --- a/deepmd/pt/model/atomic_model/__init__.py +++ b/deepmd/pt/model/atomic_model/__init__.py @@ -17,9 +17,18 @@ from .base_atomic_model import ( BaseAtomicModel, ) +from .dipole_atomic_model import ( + DPDipoleAtomicModel, +) +from .dos_atomic_model import ( + DPDOSAtomicModel, +) from .dp_atomic_model import ( DPAtomicModel, ) +from .energy_atomic_model import ( + DPEnergyAtomicModel, +) from .linear_atomic_model import ( DPZBLLinearEnergyAtomicModel, LinearEnergyAtomicModel, @@ -27,11 +36,18 @@ from .pairtab_atomic_model import ( PairTabAtomicModel, ) +from .polar_atomic_model import ( + DPPolarAtomicModel, +) __all__ = [ "BaseAtomicModel", "DPAtomicModel", + "DPDOSAtomicModel", + "DPEnergyAtomicModel", "PairTabAtomicModel", "LinearEnergyAtomicModel", + "DPPolarAtomicModel", + "DPDipoleAtomicModel", "DPZBLLinearEnergyAtomicModel", ] diff --git a/deepmd/pt/model/atomic_model/dipole_atomic_model.py b/deepmd/pt/model/atomic_model/dipole_atomic_model.py new file mode 100644 index 0000000000..1723a30f2d --- /dev/null +++ b/deepmd/pt/model/atomic_model/dipole_atomic_model.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, +) + +import torch + +from deepmd.pt.model.task.dipole import ( + DipoleFittingNet, +) + +from .dp_atomic_model import ( + DPAtomicModel, +) + + +class DPDipoleAtomicModel(DPAtomicModel): + def __init__(self, descriptor, fitting, type_map, **kwargs): + assert isinstance(fitting, DipoleFittingNet) + super().__init__(descriptor, fitting, type_map, **kwargs) + + def apply_out_stat( + self, + ret: Dict[str, torch.Tensor], + atype: torch.Tensor, + ): + # dipole not applying bias + return ret diff --git a/deepmd/pt/model/atomic_model/dos_atomic_model.py b/deepmd/pt/model/atomic_model/dos_atomic_model.py new file mode 100644 index 0000000000..5e399f2aff --- /dev/null +++ b/deepmd/pt/model/atomic_model/dos_atomic_model.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.pt.model.task.dos import ( + DOSFittingNet, +) + +from .dp_atomic_model import ( + DPAtomicModel, +) + + +class DPDOSAtomicModel(DPAtomicModel): + def __init__(self, descriptor, fitting, type_map, **kwargs): + assert isinstance(fitting, DOSFittingNet) + super().__init__(descriptor, fitting, type_map, **kwargs) diff --git a/deepmd/pt/model/atomic_model/energy_atomic_model.py b/deepmd/pt/model/atomic_model/energy_atomic_model.py new file mode 100644 index 0000000000..7cedaa1ab3 --- /dev/null +++ b/deepmd/pt/model/atomic_model/energy_atomic_model.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.pt.model.task.ener import ( + EnergyFittingNet, + EnergyFittingNetDirect, + InvarFitting, +) + +from .dp_atomic_model import ( + DPAtomicModel, +) + + +class DPEnergyAtomicModel(DPAtomicModel): + def __init__(self, descriptor, fitting, type_map, **kwargs): + assert ( + isinstance(fitting, EnergyFittingNet) + or isinstance(fitting, EnergyFittingNetDirect) + or isinstance(fitting, InvarFitting) + ) + super().__init__(descriptor, fitting, type_map, **kwargs) diff --git a/deepmd/pt/model/atomic_model/polar_atomic_model.py b/deepmd/pt/model/atomic_model/polar_atomic_model.py new file mode 100644 index 0000000000..3eb4136b6e --- /dev/null +++ b/deepmd/pt/model/atomic_model/polar_atomic_model.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, +) + +import torch + +from deepmd.pt.model.task.polarizability import ( + PolarFittingNet, +) + +from .dp_atomic_model import ( + DPAtomicModel, +) + + +class DPPolarAtomicModel(DPAtomicModel): + def __init__(self, descriptor, fitting, type_map, **kwargs): + assert isinstance(fitting, PolarFittingNet) + super().__init__(descriptor, fitting, type_map, **kwargs) + + def apply_out_stat( + self, + ret: Dict[str, torch.Tensor], + atype: torch.Tensor, + ): + # TODO: migrate bias + return ret diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 1675215d7b..260ce92099 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -30,8 +30,14 @@ Spin, ) +from .dipole_model import ( + DipoleModel, +) +from .dos_model import ( + DOSModel, +) from .dp_model import ( - DPModel, + DPModelCommon, ) from .dp_zbl_model import ( DPZBLModel, @@ -51,6 +57,9 @@ from .model import ( BaseModel, ) +from .polar_model import ( + PolarModel, +) from .spin_model import ( SpinEnergyModel, SpinModel, @@ -160,7 +169,18 @@ def get_standard_model(model_params): atom_exclude_types = model_params.get("atom_exclude_types", []) pair_exclude_types = model_params.get("pair_exclude_types", []) - model = DPModel( + if fitting_net["type"] == "dipole": + modelcls = DipoleModel + elif fitting_net["type"] == "polar": + modelcls = PolarModel + elif fitting_net["type"] == "dos": + modelcls = DOSModel + elif fitting_net["type"] in ["ener", "direct_force_ener"]: + modelcls = EnergyModel + else: + raise RuntimeError(f"Unknown fitting type: {fitting_net['type']}") + + model = modelcls( descriptor=descriptor, fitting=fitting, type_map=model_params["type_map"], @@ -183,7 +203,7 @@ def get_model(model_params): __all__ = [ "BaseModel", "get_model", - "DPModel", + "DPModelCommon", "EnergyModel", "FrozenModel", "SpinModel", diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py index 45b120771b..b73976432b 100644 --- a/deepmd/pt/model/model/dipole_model.py +++ b/deepmd/pt/model/model/dipole_model.py @@ -6,12 +6,25 @@ import torch +from deepmd.pt.model.atomic_model import ( + DPDipoleAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) + from .dp_model import ( - DPModel, + DPModelCommon, ) +from .make_model import ( + make_model, +) + +DPDOSModel_ = make_model(DPDipoleAtomicModel) -class DipoleModel(DPModel): +@BaseModel.register("dipole") +class DipoleModel(DPModelCommon, DPDOSModel_): model_type = "dipole" def __init__( @@ -19,7 +32,8 @@ def __init__( *args, **kwargs, ): - super().__init__(*args, **kwargs) + DPModelCommon.__init__(self) + DPDOSModel_.__init__(self, *args, **kwargs) def forward( self, diff --git a/deepmd/pt/model/model/dos_model.py b/deepmd/pt/model/model/dos_model.py index e043700bee..24095002d1 100644 --- a/deepmd/pt/model/model/dos_model.py +++ b/deepmd/pt/model/model/dos_model.py @@ -6,12 +6,25 @@ import torch +from deepmd.pt.model.atomic_model import ( + DPDOSAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) + from .dp_model import ( - DPModel, + DPModelCommon, ) +from .make_model import ( + make_model, +) + +DPDOSModel_ = make_model(DPDOSAtomicModel) -class DOSModel(DPModel): +@BaseModel.register("dos") +class DOSModel(DPModelCommon, DPDOSModel_): model_type = "dos" def __init__( @@ -19,7 +32,8 @@ def __init__( *args, **kwargs, ): - super().__init__(*args, **kwargs) + DPModelCommon.__init__(self) + DPDOSModel_.__init__(self, *args, **kwargs) def forward( self, diff --git a/deepmd/pt/model/model/dp_model.py b/deepmd/pt/model/model/dp_model.py index d7b3c4f4e2..fab1ff580f 100644 --- a/deepmd/pt/model/model/dp_model.py +++ b/deepmd/pt/model/model/dp_model.py @@ -1,83 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Dict, - Optional, -) - -import torch - -from deepmd.pt.model.atomic_model import ( - DPAtomicModel, -) from deepmd.pt.model.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.pt.model.model.model import ( - BaseModel, -) -from deepmd.pt.model.task.dipole import ( - DipoleFittingNet, -) -from deepmd.pt.model.task.dos import ( - DOSFittingNet, -) -from deepmd.pt.model.task.ener import ( - EnergyFittingNet, - EnergyFittingNetDirect, -) -from deepmd.pt.model.task.polarizability import ( - PolarFittingNet, -) -from .make_model import ( - make_model, -) - -@BaseModel.register("standard") -class DPModel(make_model(DPAtomicModel)): - def __new__( - cls, - descriptor=None, - fitting=None, - *args, - # disallow positional atomic_model_ - atomic_model_: Optional[DPAtomicModel] = None, - **kwargs, - ): - from deepmd.pt.model.model.dipole_model import ( - DipoleModel, - ) - from deepmd.pt.model.model.dos_model import ( - DOSModel, - ) - from deepmd.pt.model.model.ener_model import ( - EnergyModel, - ) - from deepmd.pt.model.model.polar_model import ( - PolarModel, - ) - - if atomic_model_ is not None: - fitting = atomic_model_.fitting_net - else: - assert fitting is not None, "fitting network is not provided" - - # according to the fitting network to decide the type of the model - if cls is DPModel: - # map fitting to model - if isinstance(fitting, EnergyFittingNet) or isinstance( - fitting, EnergyFittingNetDirect - ): - cls = EnergyModel - elif isinstance(fitting, DipoleFittingNet): - cls = DipoleModel - elif isinstance(fitting, PolarFittingNet): - cls = PolarModel - elif isinstance(fitting, DOSFittingNet): - cls = DOSModel - # else: unknown fitting type, fall back to DPModel - return super().__new__(cls) +class DPModelCommon: + """A base class to implement common methods for all the Models.""" @classmethod def update_sel(cls, global_jdata: dict, local_jdata: dict): @@ -103,22 +31,3 @@ def get_fitting_net(self): def get_descriptor(self): """Get the descriptor.""" return self.atomic_model.descriptor - - def forward( - self, - coord, - atype, - box: Optional[torch.Tensor] = None, - fparam: Optional[torch.Tensor] = None, - aparam: Optional[torch.Tensor] = None, - do_atomic_virial: bool = False, - ) -> Dict[str, torch.Tensor]: - # directly call the forward_common method when no specific transform rule - return self.forward_common( - coord, - atype, - box, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, - ) diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index bbc82b8d77..f18e1d097f 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -6,9 +6,6 @@ import torch -from deepmd.dpmodel.model.dp_model import ( - DPModel, -) from deepmd.pt.model.atomic_model import ( DPZBLLinearEnergyAtomicModel, ) @@ -16,6 +13,9 @@ BaseModel, ) +from .dp_model import ( + DPModelCommon, +) from .make_model import ( make_model, ) @@ -116,7 +116,7 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict): The local data refer to the current class """ local_jdata_cpy = local_jdata.copy() - local_jdata_cpy["dpmodel"] = DPModel.update_sel( + local_jdata_cpy["dpmodel"] = DPModelCommon.update_sel( global_jdata, local_jdata["dpmodel"] ) return local_jdata_cpy diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index 5217293623..4a0eb49945 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -6,12 +6,25 @@ import torch +from deepmd.pt.model.atomic_model import ( + DPEnergyAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) + from .dp_model import ( - DPModel, + DPModelCommon, ) +from .make_model import ( + make_model, +) + +DPEnergyModel_ = make_model(DPEnergyAtomicModel) -class EnergyModel(DPModel): +@BaseModel.register("ener") +class EnergyModel(DPModelCommon, DPEnergyModel_): model_type = "ener" def __init__( @@ -19,7 +32,8 @@ def __init__( *args, **kwargs, ): - super().__init__(*args, **kwargs) + DPModelCommon.__init__(self) + DPEnergyModel_.__init__(self, *args, **kwargs) def forward( self, diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 25a8ec9201..386b5e14f9 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -527,4 +527,23 @@ def mixed_types(self) -> bool: """ return self.atomic_model.mixed_types() + def forward( + self, + coord, + atype, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, torch.Tensor]: + # directly call the forward_common method when no specific transform rule + return self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + return CM diff --git a/deepmd/pt/model/model/polar_model.py b/deepmd/pt/model/model/polar_model.py index 403058aa47..867afefc27 100644 --- a/deepmd/pt/model/model/polar_model.py +++ b/deepmd/pt/model/model/polar_model.py @@ -6,12 +6,25 @@ import torch +from deepmd.pt.model.atomic_model import ( + DPPolarAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) + from .dp_model import ( - DPModel, + DPModelCommon, ) +from .make_model import ( + make_model, +) + +DPDOSModel_ = make_model(DPPolarAtomicModel) -class PolarModel(DPModel): +@BaseModel.register("polar") +class PolarModel(DPModelCommon, DPDOSModel_): model_type = "polar" def __init__( @@ -19,7 +32,8 @@ def __init__( *args, **kwargs, ): - super().__init__(*args, **kwargs) + DPModelCommon.__init__(self) + DPDOSModel_.__init__(self, *args, **kwargs) def forward( self, diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index df2f48e2e4..4c71344c8a 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -8,6 +8,9 @@ import torch +from deepmd.pt.model.atomic_model import ( + DPAtomicModel, +) from deepmd.pt.utils.utils import ( to_torch_tensor, ) @@ -18,8 +21,8 @@ Spin, ) -from .dp_model import ( - DPModel, +from .make_model import ( + make_model, ) @@ -474,7 +477,9 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "SpinModel": - backbone_model_obj = DPModel.deserialize(data["backbone_model"]) + backbone_model_obj = make_model(DPAtomicModel).deserialize( + data["backbone_model"] + ) spin = Spin.deserialize(data["spin"]) return cls( backbone_model=backbone_model_obj, diff --git a/deepmd/tf/model/dos.py b/deepmd/tf/model/dos.py index 265026b60a..2d244f5733 100644 --- a/deepmd/tf/model/dos.py +++ b/deepmd/tf/model/dos.py @@ -23,6 +23,7 @@ ) +@StandardModel.register("dos") class DOSModel(StandardModel): """DOS model. diff --git a/deepmd/tf/model/ener.py b/deepmd/tf/model/ener.py index a493fe0517..4a0195b334 100644 --- a/deepmd/tf/model/ener.py +++ b/deepmd/tf/model/ener.py @@ -32,6 +32,7 @@ ) +@StandardModel.register("ener") class EnerModel(StandardModel): """Energy model. diff --git a/deepmd/tf/model/tensor.py b/deepmd/tf/model/tensor.py index b232f40b13..ff20b771a7 100644 --- a/deepmd/tf/model/tensor.py +++ b/deepmd/tf/model/tensor.py @@ -242,11 +242,13 @@ def __init__(self, *args, **kwargs) -> None: TensorModel.__init__(self, "wfc", *args, **kwargs) +@StandardModel.register("dipole") class DipoleModel(TensorModel): def __init__(self, *args, **kwargs) -> None: TensorModel.__init__(self, "dipole", *args, **kwargs) +@StandardModel.register("polar") class PolarModel(TensorModel): def __init__(self, *args, **kwargs) -> None: TensorModel.__init__(self, "polar", *args, **kwargs) diff --git a/source/tests/common/dpmodel/test_dp_model.py b/source/tests/common/dpmodel/test_dp_model.py index 9121c7cd07..67ab6a7c32 100644 --- a/source/tests/common/dpmodel/test_dp_model.py +++ b/source/tests/common/dpmodel/test_dp_model.py @@ -9,8 +9,8 @@ from deepmd.dpmodel.fitting import ( InvarFitting, ) -from deepmd.dpmodel.model import ( - DPModel, +from deepmd.dpmodel.model.ener_model import ( + EnergyModel, ) from .case_single_frame_with_nlist import ( @@ -40,8 +40,8 @@ def test_self_consistency( mixed_types=ds.mixed_types(), ) type_map = ["foo", "bar"] - md0 = DPModel(ds, ft, type_map=type_map) - md1 = DPModel.deserialize(md0.serialize()) + md0 = EnergyModel(ds, ft, type_map=type_map) + md1 = EnergyModel.deserialize(md0.serialize()) ret0 = md0.call_lower(self.coord_ext, self.atype_ext, self.nlist) ret1 = md1.call_lower(self.coord_ext, self.atype_ext, self.nlist) @@ -70,7 +70,7 @@ def test_prec_consistency(self): fparam = rng.normal(size=[self.nf, nfp]) aparam = rng.normal(size=[self.nf, nloc, nap]) - md1 = DPModel(ds, ft, type_map=type_map) + md1 = EnergyModel(ds, ft, type_map=type_map) args64 = [self.coord_ext, self.atype_ext, self.nlist] args64[0] = args64[0].astype(np.float64) @@ -122,7 +122,7 @@ def test_prec_consistency(self): fparam = rng.normal(size=[self.nf, nfp]) aparam = rng.normal(size=[self.nf, nloc, nap]) - md1 = DPModel(ds, ft, type_map=type_map) + md1 = EnergyModel(ds, ft, type_map=type_map) args64 = [self.coord, self.atype, self.cell] args64[0] = args64[0].astype(np.float64) diff --git a/source/tests/common/dpmodel/test_nlist.py b/source/tests/common/dpmodel/test_nlist.py index ee8a7139e7..404232013b 100644 --- a/source/tests/common/dpmodel/test_nlist.py +++ b/source/tests/common/dpmodel/test_nlist.py @@ -9,8 +9,8 @@ from deepmd.dpmodel.fitting import ( InvarFitting, ) -from deepmd.dpmodel.model import ( - DPModel, +from deepmd.dpmodel.model.ener_model import ( + EnergyModel, ) from deepmd.dpmodel.utils import ( build_multiple_neighbor_list, @@ -65,7 +65,7 @@ def setUp(self): mixed_types=ds.mixed_types(), ) type_map = ["foo", "bar"] - self.md = DPModel(ds, ft, type_map=type_map) + self.md = EnergyModel(ds, ft, type_map=type_map) def test_nlist_eq(self): # n_nnei == nnei diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index da5033a3b6..c8ff9e4dcf 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -7,7 +7,7 @@ import numpy as np -from deepmd.dpmodel.model.dp_model import DPModel as EnergyModelDP +from deepmd.dpmodel.model.ener_model import EnergyModel as EnergyModelDP from deepmd.dpmodel.model.model import get_model as get_model_dp from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, diff --git a/source/tests/pt/model/test_dp_model.py b/source/tests/pt/model/test_dp_model.py index 7470cf96d0..3d59a33ca2 100644 --- a/source/tests/pt/model/test_dp_model.py +++ b/source/tests/pt/model/test_dp_model.py @@ -4,18 +4,17 @@ import numpy as np import torch -from deepmd.dpmodel import DPModel as DPDPModel from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA -from deepmd.dpmodel.fitting import InvarFitting as DPInvarFitting +from deepmd.dpmodel.fitting import EnergyFittingNet as DPEnergyFittingNet +from deepmd.dpmodel.model.ener_model import EnergyModel as DPEnergyModel from deepmd.pt.model.descriptor.se_a import ( DescrptSeA, ) from deepmd.pt.model.model import ( - DPModel, EnergyModel, ) from deepmd.pt.model.task.ener import ( - InvarFitting, + EnergyFittingNet, ) from deepmd.pt.utils import ( env, @@ -49,16 +48,14 @@ def test_self_consistency(self): self.rcut_smth, self.sel, ).to(env.DEVICE) - ft = InvarFitting( - "energy", + ft = EnergyFittingNet( self.nt, ds.get_dim_out(), - 1, mixed_types=ds.mixed_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - md0 = DPModel(ds, ft, type_map=type_map).to(env.DEVICE) - md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE) + md1 = EnergyModel.deserialize(md0.serialize()).to(env.DEVICE) args = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]] ret0 = md0.forward_common(*args) ret1 = md1.forward_common(*args) @@ -122,18 +119,16 @@ def test_dp_consistency(self): self.rcut_smth, self.sel, ) - ft = DPInvarFitting( - "energy", + ft = DPEnergyFittingNet( self.nt, ds.get_dim_out(), - 1, mixed_types=ds.mixed_types(), numb_fparam=nfp, numb_aparam=nap, ) type_map = ["foo", "bar"] - md0 = DPDPModel(ds, ft, type_map=type_map) - md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + md0 = DPEnergyModel(ds, ft, type_map=type_map) + md1 = EnergyModel.deserialize(md0.serialize()).to(env.DEVICE) rng = np.random.default_rng() fparam = rng.normal(size=[self.nf, nfp]) @@ -163,18 +158,16 @@ def test_dp_consistency_nopbc(self): self.rcut_smth, self.sel, ) - ft = DPInvarFitting( - "energy", + ft = DPEnergyFittingNet( self.nt, ds.get_dim_out(), - 1, mixed_types=ds.mixed_types(), numb_fparam=nfp, numb_aparam=nap, ) type_map = ["foo", "bar"] - md0 = DPDPModel(ds, ft, type_map=type_map) - md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + md0 = DPEnergyModel(ds, ft, type_map=type_map) + md1 = EnergyModel.deserialize(md0.serialize()).to(env.DEVICE) rng = np.random.default_rng() fparam = rng.normal(size=[self.nf, nfp]) @@ -204,11 +197,9 @@ def test_prec_consistency(self): self.rcut_smth, self.sel, ) - ft = DPInvarFitting( - "energy", + ft = DPEnergyFittingNet( self.nt, ds.get_dim_out(), - 1, mixed_types=ds.mixed_types(), ) nfp, nap = 2, 3 @@ -216,8 +207,8 @@ def test_prec_consistency(self): fparam = rng.normal(size=[self.nf, nfp]) aparam = rng.normal(size=[self.nf, nloc, nap]) - md0 = DPDPModel(ds, ft, type_map=type_map) - md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + md0 = DPEnergyModel(ds, ft, type_map=type_map) + md1 = EnergyModel.deserialize(md0.serialize()).to(env.DEVICE) args64 = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]] args64[0] = args64[0].to(torch.float64) @@ -259,16 +250,14 @@ def test_self_consistency(self): self.rcut_smth, self.sel, ).to(env.DEVICE) - ft = InvarFitting( - "energy", + ft = EnergyFittingNet( self.nt, ds.get_dim_out(), - 1, mixed_types=ds.mixed_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - md0 = DPModel(ds, ft, type_map=type_map).to(env.DEVICE) - md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE) + md1 = EnergyModel.deserialize(md0.serialize()).to(env.DEVICE) args = [ to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] ] @@ -310,16 +299,14 @@ def test_dp_consistency(self): self.rcut_smth, self.sel, ) - ft = DPInvarFitting( - "energy", + ft = DPEnergyFittingNet( self.nt, ds.get_dim_out(), - 1, mixed_types=ds.mixed_types(), ) type_map = ["foo", "bar"] - md0 = DPDPModel(ds, ft, type_map=type_map) - md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + md0 = DPEnergyModel(ds, ft, type_map=type_map) + md1 = EnergyModel.deserialize(md0.serialize()).to(env.DEVICE) args0 = [self.coord_ext, self.atype_ext, self.nlist] args1 = [ to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] @@ -345,11 +332,9 @@ def test_prec_consistency(self): self.rcut_smth, self.sel, ) - ft = DPInvarFitting( - "energy", + ft = DPEnergyFittingNet( self.nt, ds.get_dim_out(), - 1, mixed_types=ds.mixed_types(), ) nfp, nap = 2, 3 @@ -357,8 +342,8 @@ def test_prec_consistency(self): fparam = rng.normal(size=[self.nf, nfp]) aparam = rng.normal(size=[self.nf, nloc, nap]) - md0 = DPDPModel(ds, ft, type_map=type_map) - md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE) + md0 = DPEnergyModel(ds, ft, type_map=type_map) + md1 = EnergyModel.deserialize(md0.serialize()).to(env.DEVICE) args64 = [ to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] @@ -397,15 +382,13 @@ def test_jit(self): self.rcut_smth, self.sel, ).to(env.DEVICE) - ft = InvarFitting( - "energy", + ft = EnergyFittingNet( self.nt, ds.get_dim_out(), - 1, mixed_types=ds.mixed_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - md0 = DPModel(ds, ft, type_map=type_map).to(env.DEVICE) + md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE) md0 = torch.jit.script(md0) md0.get_rcut() md0.get_type_map() @@ -447,15 +430,13 @@ def setUp(self): self.rcut_smth, self.sel, ).to(env.DEVICE) - ft = InvarFitting( - "energy", + ft = EnergyFittingNet( self.nt, ds.get_dim_out(), - 1, mixed_types=ds.mixed_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] - self.md = DPModel(ds, ft, type_map=type_map).to(env.DEVICE) + self.md = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE) def test_nlist_eq(self): # n_nnei == nnei @@ -520,11 +501,9 @@ def test_self_consistency(self): self.rcut_smth, self.sel, ).to(env.DEVICE) - ft = InvarFitting( - "energy", + ft = EnergyFittingNet( self.nt, ds.get_dim_out(), - 1, mixed_types=ds.mixed_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] @@ -590,11 +569,9 @@ def test_self_consistency(self): self.rcut_smth, self.sel, ).to(env.DEVICE) - ft = InvarFitting( - "energy", + ft = EnergyFittingNet( self.nt, ds.get_dim_out(), - 1, mixed_types=ds.mixed_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] @@ -640,11 +617,9 @@ def test_jit(self): self.rcut_smth, self.sel, ).to(env.DEVICE) - ft = InvarFitting( - "energy", + ft = EnergyFittingNet( self.nt, ds.get_dim_out(), - 1, mixed_types=ds.mixed_types(), ).to(env.DEVICE) type_map = ["foo", "bar"] diff --git a/source/tests/pt/model/test_make_hessian_model.py b/source/tests/pt/model/test_make_hessian_model.py index 7d9ae2b810..b94e83bafc 100644 --- a/source/tests/pt/model/test_make_hessian_model.py +++ b/source/tests/pt/model/test_make_hessian_model.py @@ -11,11 +11,9 @@ DescrptSeA, ) from deepmd.pt.model.model import ( + EnergyModel, make_hessian_model, ) -from deepmd.pt.model.model.dp_model import ( - DPModel, -) from deepmd.pt.model.task.ener import ( InvarFitting, ) @@ -159,10 +157,10 @@ def setUp(self): neuron=[4, 4, 4], ).to(env.DEVICE) type_map = ["foo", "bar"] - self.model_hess = make_hessian_model(DPModel)(ds, ft0, type_map=type_map).to( - env.DEVICE - ) - self.model_valu = DPModel.deserialize(self.model_hess.serialize()) + self.model_hess = make_hessian_model(EnergyModel)( + ds, ft0, type_map=type_map + ).to(env.DEVICE) + self.model_valu = EnergyModel.deserialize(self.model_hess.serialize()) self.model_hess.requires_hessian("energy") def test_output_def(self):