Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BaseModel; store type in serialization #3335

Merged
merged 17 commits into from
Feb 27, 2024
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def forward_atomic(

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "standard",
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
Expand All @@ -138,6 +140,8 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
Expand Down
11 changes: 11 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import sys
from abc import (
abstractmethod,
Expand Down Expand Up @@ -182,12 +183,17 @@ def fitting_output_def(self) -> FittingOutputDef:
@staticmethod
def serialize(models) -> dict:
return {
"@class": "Model",
"type": "linear",
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
}

@staticmethod
def deserialize(data) -> List[BaseAtomicModel]:
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
model_names = data["model_name"]
models = [
getattr(sys.modules[__name__], name).deserialize(model)
Expand Down Expand Up @@ -263,6 +269,8 @@ def __init__(

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "zbl",
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
Expand All @@ -271,6 +279,9 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
sw_rmin = data["sw_rmin"]
sw_rmax = data["sw_rmax"]
smin_alpha = data["smin_alpha"]
Expand Down
12 changes: 11 additions & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Dict,
List,
Expand Down Expand Up @@ -105,10 +106,19 @@ def mixed_types(self) -> bool:
return True

def serialize(self) -> dict:
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}
return {
"@class": "Model",
"type": "pairtab",
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
rcut = data["rcut"]
sel = data["sel"]
tab = PairTab.deserialize(data["tab"])
Expand Down
6 changes: 3 additions & 3 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

import numpy as np

from deepmd.dpmodel.model.dp_model import (
DPModel,
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.output_def import (
ModelOutputDef,
Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(
self.model_path = model_file

model_data = load_dp_model(model_file)
self.dp = DPModel.deserialize(model_data["model"])
self.dp = BaseModel.deserialize(model_data["model"])
self.rcut = self.dp.get_rcut()
self.type_map = self.dp.get_type_map()
if isinstance(auto_batch_size, bool):
Expand Down
188 changes: 188 additions & 0 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import inspect
from abc import (
ABC,
abstractmethod,
)
from typing import (
Any,
Callable,
List,
Type,
)

from deepmd.common import (
j_get_type,
)
from deepmd.utils.plugin import (
Plugin,
)


class BaseBaseModel(ABC):
"""Base class for final exported model that will be directly used for inference.

The class defines some abstractmethods that will be directly called by the
inference interface. If the final model class inherits some of those methods
from other classes, `BaseModel` should be inherited as the last class to ensure
the correct method resolution order.

This class is backend-indepedent.

See Also
--------
deepmd.dpmodel.model.base_model.BaseModel
BaseModel class for DPModel backend.
"""

@abstractmethod
def __call__(self, *args: Any, **kwds: Any) -> Any:
"""Inference method.

Parameters
----------
*args : Any
The input data for inference.
**kwds : Any
The input data for inference.

Returns
-------
Any
The output of the inference.
"""
pass

Check warning on line 54 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L54

Added line #L54 was not covered by tests

@abstractmethod
def get_type_map(self) -> List[str]:
"""Get the type map."""

@abstractmethod
def get_rcut(self):
"""Get the cut-off radius."""

@abstractmethod
def get_dim_fparam(self):
"""Get the number (dimension) of frame parameters of this atomic model."""

@abstractmethod
def get_dim_aparam(self):
"""Get the number (dimension) of atomic parameters of this atomic model."""

@abstractmethod
def get_sel_type(self) -> List[int]:
"""Get the selected atom types of this model.

Only atoms with selected atom types have atomic contribution
to the result of the model.
If returning an empty list, all atom types are selected.
"""

@abstractmethod
def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).

If False, the shape is (nframes, nloc, ndim).
"""

@abstractmethod
def model_output_type(self) -> str:
"""Get the output type for the model."""

@classmethod
@abstractmethod
def get_class_by_type(cls, model_type: str) -> Type["BaseBaseModel"]:
"""Get the class by the type of the model.

Parameters
----------
model_type : str
The type of the model.

Returns
-------
Type["BaseBaseModel"]
The class of the model.
"""

@abstractmethod
def serialize(self) -> dict:
"""Serialize the model.

Returns
-------
dict
The serialized data
"""
pass

Check warning on line 117 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L117

Added line #L117 was not covered by tests

@classmethod
def deserialize(cls, data: dict) -> "BaseBaseModel":
"""Deserialize the model.

Parameters
----------
data : dict
The serialized data

Returns
-------
BaseModel
The deserialized model
"""
if inspect.isabstract(cls):
return cls.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

Check warning on line 135 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L135

Added line #L135 was not covered by tests


class BaseModel(BaseBaseModel):
"""Base class for final exported model that will be directly used for inference.

The class defines some abstractmethods that will be directly called by the
inference interface. If the final model class inherbits some of those methods
from other classes, `BaseModel` should be inherited as the last class to ensure
the correct method resolution order.

This class is for the DPModel backend.

See Also
--------
deepmd.dpmodel.model.base_model.BaseBaseModel
Backend-independent BaseModel class.
"""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable[[object], object]:
"""Register a descriptor plugin.

Parameters
----------
key : str
the key of a descriptor

Returns
-------
callable[[object], object]
the registered descriptor

Examples
--------
>>> @Model.register("some_model")
class SomeModel(Model):
pass
"""
return BaseModel.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is BaseModel:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))

Check warning on line 180 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L180

Added line #L180 was not covered by tests
return super().__new__(cls)

@classmethod
def get_class_by_type(cls, model_type: str) -> Type["BaseModel"]:
if model_type in BaseModel.__plugins.plugins:
return BaseModel.__plugins.plugins[model_type]
else:
raise RuntimeError("Unknown model type: " + model_type)

Check warning on line 188 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L188

Added line #L188 was not covered by tests
6 changes: 5 additions & 1 deletion deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)

from .make_model import (
make_model,
)


# use "class" to resolve "Variable not allowed in type expression"
class DPModel(make_model(DPAtomicModel)):
@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
Fixed Show fixed Hide fixed
pass
2 changes: 2 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def mixed_types(self) -> bool:

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "standard",
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting_net.serialize(),
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ def fitting_output_def(self) -> FittingOutputDef:
@staticmethod
def serialize(models) -> dict:
return {
"@class": "Model",
"type": "linear",
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
}
Expand Down Expand Up @@ -301,6 +303,8 @@ def __init__(

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "zbl",
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
Expand Down
8 changes: 7 additions & 1 deletion deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,13 @@ def mixed_types(self) -> bool:
return True

def serialize(self) -> dict:
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}
return {
"@class": "Model",
"type": "pairtab",
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
Expand Down
Loading