Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BaseModel; store type in serialization #3335

Merged
merged 17 commits into from
Feb 27, 2024
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "standard",
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
Expand All @@ -138,6 +140,8 @@
@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")

Check warning on line 144 in deepmd/dpmodel/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/dp_atomic_model.py#L143-L144

Added lines #L143 - L144 were not covered by tests
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
Expand Down
11 changes: 11 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import sys
from abc import (
abstractmethod,
Expand Down Expand Up @@ -182,12 +183,17 @@
@staticmethod
def serialize(models) -> dict:
return {
"@class": "Model",
"type": "linear",
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
}

@staticmethod
def deserialize(data) -> List[BaseAtomicModel]:
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")

Check warning on line 196 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L194-L196

Added lines #L194 - L196 were not covered by tests
model_names = data["model_name"]
models = [
getattr(sys.modules[__name__], name).deserialize(model)
Expand Down Expand Up @@ -263,6 +269,8 @@

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "zbl",
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
Expand All @@ -271,6 +279,9 @@

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")

Check warning on line 284 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L282-L284

Added lines #L282 - L284 were not covered by tests
sw_rmin = data["sw_rmin"]
sw_rmax = data["sw_rmax"]
smin_alpha = data["smin_alpha"]
Expand Down
12 changes: 11 additions & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Dict,
List,
Expand Down Expand Up @@ -105,10 +106,19 @@
return True

def serialize(self) -> dict:
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}
return {

Check warning on line 109 in deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py#L109

Added line #L109 was not covered by tests
"@class": "Model",
"type": "pairtab",
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")

Check warning on line 121 in deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py#L119-L121

Added lines #L119 - L121 were not covered by tests
rcut = data["rcut"]
sel = data["sel"]
tab = PairTab.deserialize(data["tab"])
Expand Down
6 changes: 3 additions & 3 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

import numpy as np

from deepmd.dpmodel.model.dp_model import (
DPModel,
from deepmd.dpmodel.model.base_model import (

Check warning on line 16 in deepmd/dpmodel/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/infer/deep_eval.py#L16

Added line #L16 was not covered by tests
BaseModel,
)
from deepmd.dpmodel.output_def import (
ModelOutputDef,
Expand Down Expand Up @@ -85,7 +85,7 @@
self.model_path = model_file

model_data = load_dp_model(model_file)
self.dp = DPModel.deserialize(model_data["model"])
self.dp = BaseModel.deserialize(model_data["model"])

Check warning on line 88 in deepmd/dpmodel/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/infer/deep_eval.py#L88

Added line #L88 was not covered by tests
self.rcut = self.dp.get_rcut()
self.type_map = self.dp.get_type_map()
if isinstance(auto_batch_size, bool):
Expand Down
160 changes: 160 additions & 0 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractmethod,
)
from typing import (
Any,
Callable,
List,
Type,
)

from deepmd.common import (
j_get_type,
)
from deepmd.utils.plugin import (
Plugin,
)


class BaseBaseModel(ABC):
"""Base class for final exported model that will be directly used for inference.

The class defines some abstractmethods that will be directly called by the
inference interface. If the final model class inherbits some of those methods
njzjz marked this conversation as resolved.
Show resolved Hide resolved
from other classes, `BaseModel` should be inherited as the last class to ensure
the correct method resolution order.

This class is backend-indepedent.

See Also
--------
deepmd.dpmodel.model.base_model.BaseModel
BaseModel class for DPModel backend.
"""

@abstractmethod
def __call__(self, *args: Any, **kwds: Any) -> Any:
"""Inference method.

Parameters
----------
*args : Any
The input data for inference.
**kwds : Any
The input data for inference.

Returns
-------
Any
The output of the inference.
"""
pass

Check warning on line 53 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L53

Added line #L53 was not covered by tests

@abstractmethod
def get_type_map(self) -> List[str]:
"""Get the type map."""

@abstractmethod
def get_rcut(self):
"""Get the cut-off radius."""

@abstractmethod
def get_dim_fparam(self):
"""Get the number (dimension) of frame parameters of this atomic model."""

@abstractmethod
def get_dim_aparam(self):
"""Get the number (dimension) of atomic parameters of this atomic model."""

@abstractmethod
def get_sel_type(self) -> List[int]:
"""Get the selected atom types of this model.

Only atoms with selected atom types have atomic contribution
to the result of the model.
If returning an empty list, all atom types are selected.
"""

@abstractmethod
def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).

If False, the shape is (nframes, nloc, ndim).
"""

@abstractmethod
def model_output_type(self) -> str:
"""Get the output type for the model."""


class BaseModel(BaseBaseModel):
"""Base class for final exported model that will be directly used for inference.

The class defines some abstractmethods that will be directly called by the
inference interface. If the final model class inherbits some of those methods
from other classes, `BaseModel` should be inherited as the last class to ensure
the correct method resolution order.

This class is for the DPModel backend.

See Also
--------
deepmd.dpmodel.model.base_model.BaseBaseModel
Backend-independent BaseModel class.
"""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable[[object], object]:
"""Register a descriptor plugin.

Parameters
----------
key : str
the key of a descriptor

Returns
-------
callable[[object], object]
the registered descriptor

Examples
--------
>>> @Fitting.register("some_fitting")
class SomeFitting(Fitting):
pass
"""
return BaseModel.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is BaseModel:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
return super().__new__(cls)

Check warning on line 135 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L133-L135

Added lines #L133 - L135 were not covered by tests

@classmethod
def get_class_by_type(cls, model_type: str) -> Type["BaseModel"]:
if model_type in BaseModel.__plugins.plugins:
return BaseModel.__plugins.plugins[model_type]

Check warning on line 140 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L139-L140

Added lines #L139 - L140 were not covered by tests
else:
raise RuntimeError("Unknown model type: " + model_type)

Check warning on line 142 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L142

Added line #L142 was not covered by tests

@classmethod
def deserialize(cls, data: dict) -> "BaseModel":
"""Deserialize the model.

Parameters
----------
data : dict
The serialized data

Returns
-------
BaseModel
The deserialized model
"""
if cls is BaseModel:
return BaseModel.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

Check warning on line 160 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L158-L160

Added lines #L158 - L160 were not covered by tests
6 changes: 5 additions & 1 deletion deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)

from .make_model import (
make_model,
)


# use "class" to resolve "Variable not allowed in type expression"
class DPModel(make_model(DPAtomicModel)):
@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
Fixed Show fixed Hide fixed
pass
2 changes: 2 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def mixed_types(self) -> bool:

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "standard",
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting_net.serialize(),
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ def fitting_output_def(self) -> FittingOutputDef:
@staticmethod
def serialize(models) -> dict:
return {
"@class": "Model",
"type": "linear",
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
}
Expand Down Expand Up @@ -301,6 +303,8 @@ def __init__(

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "zbl",
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
Expand Down
8 changes: 7 additions & 1 deletion deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,13 @@
return True

def serialize(self) -> dict:
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}
return {

Check warning on line 127 in deepmd/pt/model/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/pairtab_atomic_model.py#L127

Added line #L127 was not covered by tests
"@class": "Model",
"type": "pairtab",
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
Expand Down
39 changes: 38 additions & 1 deletion deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,46 @@
from deepmd.pt.model.atomic_model import (
DPAtomicModel,
)
from deepmd.pt.model.model.model import (

Check warning on line 5 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L5

Added line #L5 was not covered by tests
BaseModel,
)
from deepmd.pt.model.task.dipole import (

Check warning on line 8 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L8

Added line #L8 was not covered by tests
DipoleFittingNet,
)
from deepmd.pt.model.task.ener import (

Check warning on line 11 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L11

Added line #L11 was not covered by tests
EnergyFittingNet,
)
from deepmd.pt.model.task.polarizability import (

Check warning on line 14 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L14

Added line #L14 was not covered by tests
PolarFittingNet,
)

from .make_model import (
make_model,
)

DPModel = make_model(DPAtomicModel)

@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
Fixed Show fixed Hide fixed

Check warning

Code scanning / CodeQL

Conflicting attributes in base classes Warning

Base classes have conflicting values for attribute 'compute_or_load_stat':
Function compute_or_load_stat
and
Function compute_or_load_stat
.
Base classes have conflicting values for attribute 'compute_or_load_stat': Function compute_or_load_stat and
Function compute_or_load_stat
.
def __new__(cls, descriptor, fitting, **kwargs):
from deepmd.pt.model.model.dipole_model import (

Check warning on line 26 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L23-L26

Added lines #L23 - L26 were not covered by tests
DipoleModel,
)
Comment on lines +26 to +28

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.model.model.dipole_model
begins an import cycle.
from deepmd.pt.model.model.ener_model import (

Check warning on line 29 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L29

Added line #L29 was not covered by tests
EnergyModel,
)
Comment on lines +29 to +31

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.model.model.ener_model
begins an import cycle.
from deepmd.pt.model.model.polar_model import (

Check warning on line 32 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L32

Added line #L32 was not covered by tests
PolarModel,
)
Comment on lines +32 to +34

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.model.model.polar_model
begins an import cycle.

# according to the fitting network to decide the type of the model
if cls is DPModel:

Check warning on line 37 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L37

Added line #L37 was not covered by tests
# map fitting to model
if isinstance(fitting, EnergyFittingNet):
cls = EnergyModel
elif isinstance(fitting, DipoleFittingNet):
cls = DipoleModel
elif isinstance(fitting, PolarFittingNet):
cls = PolarModel

Check warning on line 44 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L39-L44

Added lines #L39 - L44 were not covered by tests
else:
raise ValueError(f"Unknown fitting type {fitting}")
return super().__new__(cls)

Check warning on line 47 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L46-L47

Added lines #L46 - L47 were not covered by tests
6 changes: 5 additions & 1 deletion deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from deepmd.pt.model.atomic_model import (
DPZBLLinearAtomicModel,
)
from deepmd.pt.model.model.model import (

Check warning on line 12 in deepmd/pt/model/model/dp_zbl_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_zbl_model.py#L12

Added line #L12 was not covered by tests
BaseModel,
)

from .make_model import (
make_model,
Expand All @@ -17,7 +20,8 @@
DPZBLModel_ = make_model(DPZBLLinearAtomicModel)


class DPZBLModel(DPZBLModel_):
@BaseModel.register("zbl")
class DPZBLModel(DPZBLModel_, BaseModel):

Check warning on line 24 in deepmd/pt/model/model/dp_zbl_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_zbl_model.py#L23-L24

Added lines #L23 - L24 were not covered by tests

Check warning

Code scanning / CodeQL

Conflicting attributes in base classes Warning

Base classes have conflicting values for attribute 'compute_or_load_stat':
Function compute_or_load_stat
and
Function compute_or_load_stat
.
Base classes have conflicting values for attribute 'compute_or_load_stat': Function compute_or_load_stat and
Function compute_or_load_stat
.
model_type = "ener"

def __init__(
Expand Down
Loading
Loading