diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 1a823e369e..178b286e79 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -130,6 +130,8 @@ def forward_atomic( def serialize(self) -> dict: return { + "@class": "Model", + "type": "standard", "type_map": self.type_map, "descriptor": self.descriptor.serialize(), "fitting": self.fitting.serialize(), @@ -138,6 +140,8 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "DPAtomicModel": data = copy.deepcopy(data) + data.pop("@class") + data.pop("type") descriptor_obj = BaseDescriptor.deserialize(data["descriptor"]) fitting_obj = BaseFitting.deserialize(data["fitting"]) obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"]) diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 520ad9185e..e1130eaf45 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy import sys from abc import ( abstractmethod, @@ -182,12 +183,17 @@ def fitting_output_def(self) -> FittingOutputDef: @staticmethod def serialize(models) -> dict: return { + "@class": "Model", + "type": "linear", "models": [model.serialize() for model in models], "model_name": [model.__class__.__name__ for model in models], } @staticmethod def deserialize(data) -> List[BaseAtomicModel]: + data = copy.deepcopy(data) + data.pop("@class") + data.pop("type") model_names = data["model_name"] models = [ getattr(sys.modules[__name__], name).deserialize(model) @@ -263,6 +269,8 @@ def __init__( def serialize(self) -> dict: return { + "@class": "Model", + "type": "zbl", "models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]), "sw_rmin": self.sw_rmin, "sw_rmax": self.sw_rmax, @@ -271,6 +279,9 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "DPZBLLinearAtomicModel": + data = copy.deepcopy(data) + data.pop("@class") + data.pop("type") sw_rmin = data["sw_rmin"] sw_rmax = data["sw_rmax"] smin_alpha = data["smin_alpha"] diff --git a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py index b6c6b8460f..d4186c990d 100644 --- a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py @@ -160,6 +160,10 @@ def do_grad_(self, var_name: str, base: str) -> bool: return self.fitting_output_def()[var_name].c_differentiable return self.fitting_output_def()[var_name].r_differentiable + def get_model_def_script(self) -> str: + # TODO: implement this method; saved to model + raise NotImplementedError + setattr(BAM, fwd_method_name, BAM.fwd) delattr(BAM, "fwd") diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index 34f6514986..dc3dfaf2ed 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy from typing import ( Dict, List, @@ -105,10 +106,19 @@ def mixed_types(self) -> bool: return True def serialize(self) -> dict: - return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel} + return { + "@class": "Model", + "type": "pairtab", + "tab": self.tab.serialize(), + "rcut": self.rcut, + "sel": self.sel, + } @classmethod def deserialize(cls, data) -> "PairTabAtomicModel": + data = copy.deepcopy(data) + data.pop("@class") + data.pop("type") rcut = data["rcut"] sel = data["sel"] tab = PairTab.deserialize(data["tab"]) diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py index 4e2349c0e8..1fd36bd7e8 100644 --- a/deepmd/dpmodel/infer/deep_eval.py +++ b/deepmd/dpmodel/infer/deep_eval.py @@ -13,8 +13,8 @@ import numpy as np -from deepmd.dpmodel.model.dp_model import ( - DPModel, +from deepmd.dpmodel.model.base_model import ( + BaseModel, ) from deepmd.dpmodel.output_def import ( ModelOutputDef, @@ -85,7 +85,7 @@ def __init__( self.model_path = model_file model_data = load_dp_model(model_file) - self.dp = DPModel.deserialize(model_data["model"]) + self.dp = BaseModel.deserialize(model_data["model"]) self.rcut = self.dp.get_rcut() self.type_map = self.dp.get_type_map() if isinstance(auto_batch_size, bool): diff --git a/deepmd/dpmodel/model/base_model.py b/deepmd/dpmodel/model/base_model.py new file mode 100644 index 0000000000..df9c926d6c --- /dev/null +++ b/deepmd/dpmodel/model/base_model.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import inspect +from abc import ( + ABC, + abstractmethod, +) +from typing import ( + Any, + List, + Type, +) + +from deepmd.utils.plugin import ( + make_plugin_registry, +) + + +def make_base_model() -> Type[object]: + class BaseBaseModel(ABC, make_plugin_registry("model")): + """Base class for final exported model that will be directly used for inference. + + The class defines some abstractmethods that will be directly called by the + inference interface. If the final model class inherits some of those methods + from other classes, `BaseModel` should be inherited as the last class to ensure + the correct method resolution order. + + This class is backend-indepedent. + + See Also + -------- + deepmd.dpmodel.model.base_model.BaseModel + BaseModel class for DPModel backend. + """ + + def __new__(cls, *args, **kwargs): + if inspect.isabstract(cls): + cls = cls.get_class_by_type(kwargs.get("type", "standard")) + return super().__new__(cls) + + @abstractmethod + def __call__(self, *args: Any, **kwds: Any) -> Any: + """Inference method. + + Parameters + ---------- + *args : Any + The input data for inference. + **kwds : Any + The input data for inference. + + Returns + ------- + Any + The output of the inference. + """ + pass + + @abstractmethod + def get_type_map(self) -> List[str]: + """Get the type map.""" + + @abstractmethod + def get_rcut(self): + """Get the cut-off radius.""" + + @abstractmethod + def get_dim_fparam(self): + """Get the number (dimension) of frame parameters of this atomic model.""" + + @abstractmethod + def get_dim_aparam(self): + """Get the number (dimension) of atomic parameters of this atomic model.""" + + @abstractmethod + def get_sel_type(self) -> List[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + + @abstractmethod + def is_aparam_nall(self) -> bool: + """Check whether the shape of atomic parameters is (nframes, nall, ndim). + + If False, the shape is (nframes, nloc, ndim). + """ + + @abstractmethod + def model_output_type(self) -> str: + """Get the output type for the model.""" + + @abstractmethod + def serialize(self) -> dict: + """Serialize the model. + + Returns + ------- + dict + The serialized data + """ + pass + + @classmethod + def deserialize(cls, data: dict) -> "BaseBaseModel": + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + BaseModel + The deserialized model + """ + if inspect.isabstract(cls): + return cls.get_class_by_type(data["type"]).deserialize(data) + raise NotImplementedError("Not implemented in class %s" % cls.__name__) + + model_def_script: str + + @abstractmethod + def get_model_def_script(self) -> str: + """Get the model definition script.""" + pass + + @abstractmethod + def get_nnei(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + # for C++ interface + pass + + @abstractmethod + def get_nsel(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + pass + + return BaseBaseModel + + +class BaseModel(make_base_model()): + """Base class for final exported model that will be directly used for inference. + + The class defines some abstractmethods that will be directly called by the + inference interface. If the final model class inherbits some of those methods + from other classes, `BaseModel` should be inherited as the last class to ensure + the correct method resolution order. + + This class is for the DPModel backend. + + See Also + -------- + deepmd.dpmodel.model.base_model.BaseBaseModel + Backend-independent BaseModel class. + """ diff --git a/deepmd/dpmodel/model/dp_model.py b/deepmd/dpmodel/model/dp_model.py index c2c40b40ba..804ce51dfd 100644 --- a/deepmd/dpmodel/model/dp_model.py +++ b/deepmd/dpmodel/model/dp_model.py @@ -2,6 +2,9 @@ from deepmd.dpmodel.atomic_model import ( DPAtomicModel, ) +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) from .make_model import ( make_model, @@ -9,5 +12,6 @@ # use "class" to resolve "Variable not allowed in type expression" -class DPModel(make_model(DPAtomicModel)): +@BaseModel.register("standard") +class DPModel(make_model(DPAtomicModel), BaseModel): pass diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 08fd5898a0..23992fb4c0 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -93,6 +93,8 @@ def mixed_types(self) -> bool: def serialize(self) -> dict: return { + "@class": "Model", + "type": "standard", "type_map": self.type_map, "descriptor": self.descriptor.serialize(), "fitting": self.fitting_net.serialize(), diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 70afbcb0bc..55752463cb 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -207,6 +207,8 @@ def fitting_output_def(self) -> FittingOutputDef: @staticmethod def serialize(models) -> dict: return { + "@class": "Model", + "type": "linear", "models": [model.serialize() for model in models], "model_name": [model.__class__.__name__ for model in models], } @@ -301,6 +303,8 @@ def __init__( def serialize(self) -> dict: return { + "@class": "Model", + "type": "zbl", "models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]), "sw_rmin": self.sw_rmin, "sw_rmax": self.sw_rmax, diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index cf5a70eb88..609116cdfc 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -124,7 +124,13 @@ def mixed_types(self) -> bool: return True def serialize(self) -> dict: - return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel} + return { + "@class": "Model", + "type": "pairtab", + "tab": self.tab.serialize(), + "rcut": self.rcut, + "sel": self.sel, + } @classmethod def deserialize(cls, data) -> "PairTabAtomicModel": diff --git a/deepmd/pt/model/model/dp_model.py b/deepmd/pt/model/model/dp_model.py index 75d3820e45..5410f518d1 100644 --- a/deepmd/pt/model/model/dp_model.py +++ b/deepmd/pt/model/model/dp_model.py @@ -2,9 +2,45 @@ from deepmd.pt.model.atomic_model import ( DPAtomicModel, ) +from deepmd.pt.model.model.model import ( + BaseModel, +) +from deepmd.pt.model.task.dipole import ( + DipoleFittingNet, +) +from deepmd.pt.model.task.ener import ( + EnergyFittingNet, +) +from deepmd.pt.model.task.polarizability import ( + PolarFittingNet, +) from .make_model import ( make_model, ) -DPModel = make_model(DPAtomicModel) + +@BaseModel.register("standard") +class DPModel(make_model(DPAtomicModel), BaseModel): + def __new__(cls, descriptor, fitting, *args, **kwargs): + from deepmd.pt.model.model.dipole_model import ( + DipoleModel, + ) + from deepmd.pt.model.model.ener_model import ( + EnergyModel, + ) + from deepmd.pt.model.model.polar_model import ( + PolarModel, + ) + + # according to the fitting network to decide the type of the model + if cls is DPModel: + # map fitting to model + if isinstance(fitting, EnergyFittingNet): + cls = EnergyModel + elif isinstance(fitting, DipoleFittingNet): + cls = DipoleModel + elif isinstance(fitting, PolarFittingNet): + cls = PolarModel + # else: unknown fitting type, fall back to DPModel + return super().__new__(cls) diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index 0fd8008f21..c8264f2007 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -9,6 +9,9 @@ from deepmd.pt.model.atomic_model import ( DPZBLLinearAtomicModel, ) +from deepmd.pt.model.model.model import ( + BaseModel, +) from .make_model import ( make_model, @@ -17,7 +20,8 @@ DPZBLModel_ = make_model(DPZBLLinearAtomicModel) -class DPZBLModel(DPZBLModel_): +@BaseModel.register("zbl") +class DPZBLModel(DPZBLModel_, BaseModel): model_type = "ener" def __init__( diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py index d98d25d539..0f5e27aea9 100644 --- a/deepmd/pt/model/model/model.py +++ b/deepmd/pt/model/model/model.py @@ -3,14 +3,62 @@ Optional, ) -import torch - +from deepmd.dpmodel.model.base_model import ( + make_base_model, +) from deepmd.utils.path import ( DPPath, ) -class BaseModel(torch.nn.Module): +# trick: torch.nn.Module should not be inherbited here, otherwise, +# the abstract method will override the method from the atomic model +# as Python resolves method lookups using the C3 linearisation. +# See https://stackoverflow.com/a/47117600/9567349 +# Take an example, this is the situation for only inheriting make_model(): +# torch.nn.Module BaseAtomicModel make_model() +# | | | +# ------------------------- | +# | | +# DPAtomicModel BaseModel +# | | +# make_model(DPAtomicModel) | +# | | +# ---------------------------------- +# | +# DPModel +# +# The order is: DPModel -> make_model(DPAtomicModel) -> DPAtomicModel -> +# torch.nn.Module -> BaseAtomicModel -> BaseModel -> make_model() +# +# However, if BaseModel also inherbits from torch.nn.Module: +# torch.nn.Module make_model() +# | | +# |--------------------------- | +# | | | +# | BaseAtomicModel | | +# | | | | +# |------------- ---------- +# | | +# DPAtomicModel BaseModel +# | | +# | | +# make_model(DPAtomicModel) | +# | | +# | | +# -------------------------------- +# | +# | +# DPModel +# +# The order is DPModel -> make_model(DPAtomicModel) -> DPAtomicModel -> +# BaseModel -> torch.nn.Module -> BaseAtomicModel -> make_model() +# BaseModel has higher proirity than BaseAtomicModel, which is not what +# we want. +# Alternatively, we can also make BaseAtomicModel in front of torch.nn.Module +# in DPAtomicModel (and other classes), but this requires the developer aware +# of it when developing it... +class BaseModel(make_base_model()): def __init__(self): """Construct a basic model for different tasks.""" super().__init__() diff --git a/deepmd/pt/utils/serialization.py b/deepmd/pt/utils/serialization.py index 91d1a3c76f..c99ddbb3c6 100644 --- a/deepmd/pt/utils/serialization.py +++ b/deepmd/pt/utils/serialization.py @@ -6,8 +6,8 @@ from deepmd.pt.model.model import ( get_model, ) -from deepmd.pt.model.model.ener_model import ( - EnergyModel, +from deepmd.pt.model.model.model import ( + BaseModel, ) from deepmd.pt.train.wrapper import ( ModelWrapper, @@ -68,8 +68,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None: """ if not model_file.endswith(".pth"): raise ValueError("PyTorch backend only supports converting .pth file") - # TODO: read class type from data; see #3319 - model = EnergyModel.deserialize(data["model"]) + model = BaseModel.deserialize(data["model"]) # JIT will happy in this way... model.model_def_script = json.dumps(data["model_def_script"]) model = torch.jit.script(model) diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index 8af4771ff6..76310834a7 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -842,6 +842,8 @@ def serialize(self, suffix: str = "") -> dict: if self.spin is not None: raise NotImplementedError("spin is not supported") return { + "@class": "Model", + "type": "standard", "type_map": self.type_map, "descriptor": self.descrpt.serialize(suffix=suffix), "fitting": self.fitting.serialize(suffix=suffix), diff --git a/deepmd/utils/plugin.py b/deepmd/utils/plugin.py index a564ed61af..e6433ee681 100644 --- a/deepmd/utils/plugin.py +++ b/deepmd/utils/plugin.py @@ -2,11 +2,14 @@ """Base of plugin systems.""" # copied from https://github.com/deepmodeling/dpdata/blob/a3e76d75de53f6076254de82d18605a010dc3b00/dpdata/plugin.py +import difflib from abc import ( ABCMeta, ) from typing import ( Callable, + Optional, + Type, ) @@ -93,3 +96,60 @@ class PluginVariant(metaclass=VariantABCMeta): """A class to remove `type` from input arguments.""" pass + + +def make_plugin_registry(name: Optional[str] = None) -> Type[object]: + """Make a plugin registry. + + Parameters + ---------- + name : Optional[str] + the name of the registry for the error message, e.g. descriptor, backend, etc. + + Examples + -------- + >>> class BaseClass(make_plugin_registry()): + pass + """ + if name is None: + name = "class" + + class PR: + __plugins = Plugin() + + @staticmethod + def register(key: str) -> Callable[[object], object]: + """Register a descriptor plugin. + + Parameters + ---------- + key : str + the key of a descriptor + + Returns + ------- + callable[[object], object] + the registered descriptor + + Examples + -------- + >>> @BaseClass.register("some_class") + class SomeClass(BaseClass): + pass + """ + return PR.__plugins.register(key) + + @classmethod + def get_class_by_type(cls, class_type: str) -> Type[object]: + """Get the class by the plugin type.""" + if class_type in PR.__plugins.plugins: + return PR.__plugins.plugins[class_type] + else: + # did you mean + matches = difflib.get_close_matches( + class_type, PR.__plugins.plugins.keys() + ) + dym_message = f"Did you mean: {matches[0]}?" if matches else "" + raise RuntimeError(f"Unknown {name} type: {class_type}. {dym_message}") + + return PR diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index 7b6d374168..be599b0805 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -16,8 +16,8 @@ from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, ) -from deepmd.infer.deep_pot import ( - DeepPot, +from deepmd.infer.deep_eval import ( + DeepEval, ) infer_path = Path(__file__).parent.parent.parent / "infer" @@ -121,7 +121,7 @@ def test_deep_eval(self): continue reference_data = copy.deepcopy(self.data) self.save_data_to_model(prefix + backend.suffixes[0], reference_data) - deep_eval = DeepPot(prefix + backend.suffixes[0]) + deep_eval = DeepEval(prefix + backend.suffixes[0]) ret = deep_eval.eval( self.coords, self.box,