Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BaseModel; store type in serialization #3335

Merged
merged 17 commits into from
Feb 27, 2024
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def forward_atomic(

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "standard",
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
Expand All @@ -138,6 +140,8 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
Expand Down
11 changes: 11 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import sys
from abc import (
abstractmethod,
Expand Down Expand Up @@ -182,12 +183,17 @@ def fitting_output_def(self) -> FittingOutputDef:
@staticmethod
def serialize(models) -> dict:
return {
"@class": "Model",
"type": "linear",
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
}

@staticmethod
def deserialize(data) -> List[BaseAtomicModel]:
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
model_names = data["model_name"]
models = [
getattr(sys.modules[__name__], name).deserialize(model)
Expand Down Expand Up @@ -263,6 +269,8 @@ def __init__(

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "zbl",
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
Expand All @@ -271,6 +279,9 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
sw_rmin = data["sw_rmin"]
sw_rmax = data["sw_rmax"]
smin_alpha = data["smin_alpha"]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@
return self.fitting_output_def()[var_name].c_differentiable
return self.fitting_output_def()[var_name].r_differentiable

def get_model_def_script(self) -> str:
# TODO: implement this method; saved to model
raise NotImplementedError

Check warning on line 165 in deepmd/dpmodel/atomic_model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/make_base_atomic_model.py#L165

Added line #L165 was not covered by tests

setattr(BAM, fwd_method_name, BAM.fwd)
delattr(BAM, "fwd")

Expand Down
12 changes: 11 additions & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Dict,
List,
Expand Down Expand Up @@ -105,10 +106,19 @@ def mixed_types(self) -> bool:
return True

def serialize(self) -> dict:
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}
return {
"@class": "Model",
"type": "pairtab",
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
rcut = data["rcut"]
sel = data["sel"]
tab = PairTab.deserialize(data["tab"])
Expand Down
6 changes: 3 additions & 3 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

import numpy as np

from deepmd.dpmodel.model.dp_model import (
DPModel,
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.output_def import (
ModelOutputDef,
Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(
self.model_path = model_file

model_data = load_dp_model(model_file)
self.dp = DPModel.deserialize(model_data["model"])
self.dp = BaseModel.deserialize(model_data["model"])
self.rcut = self.dp.get_rcut()
self.type_map = self.dp.get_type_map()
if isinstance(auto_batch_size, bool):
Expand Down
158 changes: 158 additions & 0 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import inspect
from abc import (
ABC,
abstractmethod,
)
from typing import (
Any,
List,
Type,
)

from deepmd.utils.plugin import (
make_plugin_registry,
)


def make_base_model() -> Type[object]:
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
class BaseBaseModel(ABC, make_plugin_registry("model")):
"""Base class for final exported model that will be directly used for inference.

The class defines some abstractmethods that will be directly called by the
inference interface. If the final model class inherits some of those methods
from other classes, `BaseModel` should be inherited as the last class to ensure
the correct method resolution order.

This class is backend-indepedent.

See Also
--------
deepmd.dpmodel.model.base_model.BaseModel
BaseModel class for DPModel backend.
"""

def __new__(cls, *args, **kwargs):
if inspect.isabstract(cls):
cls = cls.get_class_by_type(kwargs.get("type", "standard"))

Check warning on line 37 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L37

Added line #L37 was not covered by tests
return super().__new__(cls)

@abstractmethod
def __call__(self, *args: Any, **kwds: Any) -> Any:
"""Inference method.

Parameters
----------
*args : Any
The input data for inference.
**kwds : Any
The input data for inference.

Returns
-------
Any
The output of the inference.
"""
pass

Check warning on line 56 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L56

Added line #L56 was not covered by tests

@abstractmethod
def get_type_map(self) -> List[str]:
"""Get the type map."""

@abstractmethod
def get_rcut(self):
"""Get the cut-off radius."""

@abstractmethod
def get_dim_fparam(self):
"""Get the number (dimension) of frame parameters of this atomic model."""

@abstractmethod
def get_dim_aparam(self):
"""Get the number (dimension) of atomic parameters of this atomic model."""

@abstractmethod
def get_sel_type(self) -> List[int]:
"""Get the selected atom types of this model.

Only atoms with selected atom types have atomic contribution
to the result of the model.
If returning an empty list, all atom types are selected.
"""

@abstractmethod
def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).

If False, the shape is (nframes, nloc, ndim).
"""

@abstractmethod
def model_output_type(self) -> str:
"""Get the output type for the model."""

@abstractmethod
def serialize(self) -> dict:
"""Serialize the model.

Returns
-------
dict
The serialized data
"""
pass

Check warning on line 103 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L103

Added line #L103 was not covered by tests

@classmethod
def deserialize(cls, data: dict) -> "BaseBaseModel":
"""Deserialize the model.

Parameters
----------
data : dict
The serialized data

Returns
-------
BaseModel
The deserialized model
"""
if inspect.isabstract(cls):
return cls.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

Check warning on line 121 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L121

Added line #L121 was not covered by tests

model_def_script: str

@abstractmethod
def get_model_def_script(self) -> str:
"""Get the model definition script."""
pass

Check warning on line 128 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L128

Added line #L128 was not covered by tests

@abstractmethod
def get_nnei(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
# for C++ interface
pass

Check warning on line 134 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L134

Added line #L134 was not covered by tests

@abstractmethod
def get_nsel(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
pass

Check warning on line 139 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L139

Added line #L139 was not covered by tests

return BaseBaseModel


class BaseModel(make_base_model()):
"""Base class for final exported model that will be directly used for inference.

The class defines some abstractmethods that will be directly called by the
inference interface. If the final model class inherbits some of those methods
from other classes, `BaseModel` should be inherited as the last class to ensure
the correct method resolution order.

This class is for the DPModel backend.

See Also
--------
deepmd.dpmodel.model.base_model.BaseBaseModel
Backend-independent BaseModel class.
"""
6 changes: 5 additions & 1 deletion deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)

from .make_model import (
make_model,
)


# use "class" to resolve "Variable not allowed in type expression"
class DPModel(make_model(DPAtomicModel)):
@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
Fixed Show fixed Hide fixed
pass
2 changes: 2 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def mixed_types(self) -> bool:

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "standard",
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting_net.serialize(),
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ def fitting_output_def(self) -> FittingOutputDef:
@staticmethod
def serialize(models) -> dict:
return {
"@class": "Model",
"type": "linear",
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
}
Expand Down Expand Up @@ -301,6 +303,8 @@ def __init__(

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "zbl",
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
Expand Down
8 changes: 7 additions & 1 deletion deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,13 @@ def mixed_types(self) -> bool:
return True

def serialize(self) -> dict:
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}
return {
"@class": "Model",
"type": "pairtab",
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
Expand Down
38 changes: 37 additions & 1 deletion deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,45 @@
from deepmd.pt.model.atomic_model import (
DPAtomicModel,
)
from deepmd.pt.model.model.model import (
BaseModel,
)
from deepmd.pt.model.task.dipole import (
DipoleFittingNet,
)
from deepmd.pt.model.task.ener import (
EnergyFittingNet,
)
from deepmd.pt.model.task.polarizability import (
PolarFittingNet,
)

from .make_model import (
make_model,
)

DPModel = make_model(DPAtomicModel)

@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
Fixed Show fixed Hide fixed

Check warning

Code scanning / CodeQL

Conflicting attributes in base classes Warning

Base classes have conflicting values for attribute 'compute_or_load_stat':
Function compute_or_load_stat
and
Function compute_or_load_stat
.
Base classes have conflicting values for attribute 'compute_or_load_stat': Function compute_or_load_stat and
Function compute_or_load_stat
.
def __new__(cls, descriptor, fitting, *args, **kwargs):
from deepmd.pt.model.model.dipole_model import (
DipoleModel,
)
Comment on lines +26 to +28

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.model.model.dipole_model
begins an import cycle.
from deepmd.pt.model.model.ener_model import (
EnergyModel,
)
Comment on lines +29 to +31

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.model.model.ener_model
begins an import cycle.
from deepmd.pt.model.model.polar_model import (
PolarModel,
)
Comment on lines +32 to +34

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.model.model.polar_model
begins an import cycle.

# according to the fitting network to decide the type of the model
if cls is DPModel:
# map fitting to model
if isinstance(fitting, EnergyFittingNet):
cls = EnergyModel
elif isinstance(fitting, DipoleFittingNet):
cls = DipoleModel

Check warning on line 42 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L42

Added line #L42 was not covered by tests
elif isinstance(fitting, PolarFittingNet):
cls = PolarModel

Check warning on line 44 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L44

Added line #L44 was not covered by tests
# else: unknown fitting type, fall back to DPModel
return super().__new__(cls)
Loading