Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BaseModel; store type in serialization #3335

Merged
merged 17 commits into from
Feb 27, 2024
278 changes: 126 additions & 152 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,136 +6,142 @@
)
from typing import (
Any,
Callable,
List,
Type,
)

from deepmd.common import (
j_get_type,
)
from deepmd.utils.plugin import (
Plugin,
make_plugin_registry,
)


class BaseBaseModel(ABC):
"""Base class for final exported model that will be directly used for inference.

The class defines some abstractmethods that will be directly called by the
inference interface. If the final model class inherits some of those methods
from other classes, `BaseModel` should be inherited as the last class to ensure
the correct method resolution order.

This class is backend-indepedent.

See Also
--------
deepmd.dpmodel.model.base_model.BaseModel
BaseModel class for DPModel backend.
"""

@abstractmethod
def __call__(self, *args: Any, **kwds: Any) -> Any:
"""Inference method.

Parameters
----------
*args : Any
The input data for inference.
**kwds : Any
The input data for inference.

Returns
-------
Any
The output of the inference.
"""
pass

@abstractmethod
def get_type_map(self) -> List[str]:
"""Get the type map."""

@abstractmethod
def get_rcut(self):
"""Get the cut-off radius."""
def make_base_model() -> Type[object]:
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
class BaseBaseModel(ABC, make_plugin_registry("model")):
"""Base class for final exported model that will be directly used for inference.

@abstractmethod
def get_dim_fparam(self):
"""Get the number (dimension) of frame parameters of this atomic model."""
The class defines some abstractmethods that will be directly called by the
inference interface. If the final model class inherits some of those methods
from other classes, `BaseModel` should be inherited as the last class to ensure
the correct method resolution order.

@abstractmethod
def get_dim_aparam(self):
"""Get the number (dimension) of atomic parameters of this atomic model."""
This class is backend-indepedent.

@abstractmethod
def get_sel_type(self) -> List[int]:
"""Get the selected atom types of this model.

Only atoms with selected atom types have atomic contribution
to the result of the model.
If returning an empty list, all atom types are selected.
"""

@abstractmethod
def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).

If False, the shape is (nframes, nloc, ndim).
"""

@abstractmethod
def model_output_type(self) -> str:
"""Get the output type for the model."""

@classmethod
@abstractmethod
def get_class_by_type(cls, model_type: str) -> Type["BaseBaseModel"]:
"""Get the class by the type of the model.

Parameters
----------
model_type : str
The type of the model.

Returns
-------
Type["BaseBaseModel"]
The class of the model.
"""

@abstractmethod
def serialize(self) -> dict:
"""Serialize the model.

Returns
-------
dict
The serialized data
"""
pass

@classmethod
def deserialize(cls, data: dict) -> "BaseBaseModel":
"""Deserialize the model.

Parameters
----------
data : dict
The serialized data

Returns
-------
BaseModel
The deserialized model
See Also
--------
deepmd.dpmodel.model.base_model.BaseModel
BaseModel class for DPModel backend.
"""
if inspect.isabstract(cls):
return cls.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)


class BaseModel(BaseBaseModel):
def __new__(cls, *args, **kwargs):
if inspect.isabstract(cls):
cls = cls.get_class_by_type(kwargs.get("type", "standard"))
return super().__new__(cls)

@abstractmethod
def __call__(self, *args: Any, **kwds: Any) -> Any:
"""Inference method.

Parameters
----------
*args : Any
The input data for inference.
**kwds : Any
The input data for inference.

Returns
-------
Any
The output of the inference.
"""
pass

@abstractmethod
def get_type_map(self) -> List[str]:
"""Get the type map."""

@abstractmethod
def get_rcut(self):
"""Get the cut-off radius."""

@abstractmethod
def get_dim_fparam(self):
"""Get the number (dimension) of frame parameters of this atomic model."""

@abstractmethod
def get_dim_aparam(self):
"""Get the number (dimension) of atomic parameters of this atomic model."""

@abstractmethod
def get_sel_type(self) -> List[int]:
"""Get the selected atom types of this model.

Only atoms with selected atom types have atomic contribution
to the result of the model.
If returning an empty list, all atom types are selected.
"""

@abstractmethod
def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).

If False, the shape is (nframes, nloc, ndim).
"""

@abstractmethod
def model_output_type(self) -> str:
"""Get the output type for the model."""

@abstractmethod
def serialize(self) -> dict:
"""Serialize the model.

Returns
-------
dict
The serialized data
"""
pass

@classmethod
def deserialize(cls, data: dict) -> "BaseBaseModel":
"""Deserialize the model.

Parameters
----------
data : dict
The serialized data

Returns
-------
BaseModel
The deserialized model
"""
if inspect.isabstract(cls):
return cls.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

model_def_script: str

@abstractmethod
def get_model_def_script(self) -> str:
"""Get the model definition script."""
pass

@abstractmethod
def get_nnei(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
# for C++ interface
pass

@abstractmethod
def get_nsel(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
pass

return BaseBaseModel


class BaseModel(make_base_model()):
"""Base class for final exported model that will be directly used for inference.

The class defines some abstractmethods that will be directly called by the
Expand All @@ -151,38 +157,6 @@ class BaseModel(BaseBaseModel):
Backend-independent BaseModel class.
"""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable[[object], object]:
"""Register a descriptor plugin.

Parameters
----------
key : str
the key of a descriptor

Returns
-------
callable[[object], object]
the registered descriptor

Examples
--------
>>> @Model.register("some_model")
class SomeModel(Model):
pass
"""
return BaseModel.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is BaseModel:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
return super().__new__(cls)

@classmethod
def get_class_by_type(cls, model_type: str) -> Type["BaseModel"]:
if model_type in BaseModel.__plugins.plugins:
return BaseModel.__plugins.plugins[model_type]
else:
raise RuntimeError("Unknown model type: " + model_type)
def get_model_def_script(self) -> str:
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
# TODO: implement this method; saved to model
raise NotImplementedError
Loading
Loading