Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore: refactor dpmodel #3663

Merged
merged 44 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
02e546c
chore: try remove dp model
anyangml Apr 11, 2024
4c026fb
chore: try remove dp model
anyangml Apr 11, 2024
81573d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
3ad9637
fix: import
anyangml Apr 11, 2024
c64e520
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
acd1b99
fix: import
anyangml Apr 11, 2024
db2751c
fix: UTs
anyangml Apr 11, 2024
9612b02
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
a7384d2
fix: UTs
anyangml Apr 11, 2024
30bd378
fix: UTs
anyangml Apr 11, 2024
6c9185e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
24859ac
fix: UTs
anyangml Apr 11, 2024
754e18a
fix: UTs
anyangml Apr 11, 2024
56accdf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
9f437ed
fix: address comments
anyangml Apr 12, 2024
ed34307
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2024
72a34fc
fix: import
anyangml Apr 12, 2024
5355428
fix: UTs
anyangml Apr 12, 2024
257936d
feat: try remove standard
anyangml Apr 13, 2024
a51a792
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
4b14d54
fix: precommit
anyangml Apr 13, 2024
5484375
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
3c5c795
fix: import
anyangml Apr 13, 2024
224ef27
fix: UTs
anyangml Apr 13, 2024
bd079b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
53dbe13
fix: UTs
anyangml Apr 13, 2024
78ff90d
fix: UTs
anyangml Apr 13, 2024
7f99512
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
23fa34d
fix: try modify argcheck
anyangml Apr 13, 2024
16583de
fix: try modify argcheck
anyangml Apr 13, 2024
23fd899
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
4fe3661
fix: plugin
anyangml Apr 13, 2024
576e75d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
d34a7d4
Merge branch 'devel' into chore/refactor-dpmodel
anyangml Apr 14, 2024
3173978
fix:inheritance
anyangml Apr 14, 2024
1e54270
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2024
353ca6d
fix: revert changes
anyangml Apr 14, 2024
8de02ee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2024
e2a80b3
fix: UTs
anyangml Apr 14, 2024
efcf951
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2024
0c100fa
fix:UTs
anyangml Apr 14, 2024
752132d
fix: UTs
anyangml Apr 14, 2024
ff30173
fix: tf register
anyangml Apr 15, 2024
7399d3e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deepmd/dpmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
NativeOP,
)
from .model import (
DPModel,
DPModelCommon,
)
from .output_def import (
FittingOutputDef,
Expand All @@ -19,7 +19,7 @@
)

__all__ = [
"DPModel",
"DPModelCommon",
"PRECISION_DICT",
"DEFAULT_PRECISION",
"NativeOP",
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""

from .dp_model import (
DPModel,
DPModelCommon,
)
from .make_model import (
make_model,
Expand All @@ -23,7 +23,7 @@
)

__all__ = [
"DPModel",
"DPModelCommon",
"SpinModel",
"make_model",
]
17 changes: 14 additions & 3 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ class BaseBaseModel(ABC, PluginVariant, make_plugin_registry("model")):

def __new__(cls, *args, **kwargs):
if inspect.isabstract(cls):
cls = cls.get_class_by_type(kwargs.get("type", "standard"))
# getting model type based on fitting type
model_type = kwargs.get("type", "standard")
if model_type == "standard":
model_type = kwargs.get("fitting", {}).get("type", "ener")
cls = cls.get_class_by_type(model_type)
return super().__new__(cls)

@abstractmethod
Expand Down Expand Up @@ -118,7 +122,10 @@ def deserialize(cls, data: dict) -> "BaseBaseModel":
The deserialized model
"""
if inspect.isabstract(cls):
return cls.get_class_by_type(data["type"]).deserialize(data)
model_type = data.get("type", "standard")
if model_type == "standard":
model_type = data.get("fitting", {}).get("type", "ener")
return cls.get_class_by_type(model_type).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

model_def_script: str
Expand Down Expand Up @@ -151,7 +158,11 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict):
local_jdata : dict
The local data refer to the current class
"""
cls = cls.get_class_by_type(local_jdata.get("type", "standard"))
# getting model type based on fitting type
model_type = local_jdata.get("type", "standard")
if model_type == "standard":
model_type = local_jdata.get("fitting", {}).get("type", "ener")
cls = cls.get_class_by_type(model_type)
return cls.update_sel(global_jdata, local_jdata)

return BaseBaseModel
Expand Down
13 changes: 1 addition & 12 deletions deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)

from .make_model import (
make_model,
)


# use "class" to resolve "Variable not allowed in type expression"
@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel)):
class DPModelCommon:
@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Expand Down
27 changes: 27 additions & 0 deletions deepmd/dpmodel/model/ener_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)

from .dp_model import (
DPModelCommon,
)
from .make_model import (
make_model,
)

DPEnergyModel_ = make_model(DPAtomicModel)


@BaseModel.register("ener")
class EnergyModel(DPModelCommon, DPEnergyModel_):
def __init__(
self,
*args,
**kwargs,
):
DPModelCommon.__init__(self)
DPEnergyModel_.__init__(self, *args, **kwargs)
10 changes: 5 additions & 5 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from deepmd.dpmodel.fitting.ener_fitting import (
EnergyFittingNet,
)
from deepmd.dpmodel.model.dp_model import (
DPModel,
from deepmd.dpmodel.model.ener_model import (
EnergyModel,
)
from deepmd.dpmodel.model.spin_model import (
SpinModel,
Expand All @@ -16,8 +16,8 @@
)


def get_standard_model(data: dict) -> DPModel:
"""Get a standard DPModel from a dictionary.
def get_standard_model(data: dict) -> EnergyModel:
"""Get a EnergyModel from a dictionary.

Parameters
----------
Expand All @@ -41,7 +41,7 @@ def get_standard_model(data: dict) -> DPModel:
)
else:
raise ValueError(f"Unknown fitting type {fitting_type}")
return DPModel(
return EnergyModel(
descriptor=descriptor,
fitting=fitting,
type_map=data["type_map"],
Expand Down
11 changes: 8 additions & 3 deletions deepmd/dpmodel/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@

import numpy as np

from deepmd.dpmodel.model.dp_model import (
DPModel,
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.model.make_model import (
make_model,
)
from deepmd.utils.spin import (
Spin,
Expand Down Expand Up @@ -259,7 +262,9 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "SpinModel":
backbone_model_obj = DPModel.deserialize(data["backbone_model"])
backbone_model_obj = make_model(DPAtomicModel).deserialize(
data["backbone_model"]
)
spin = Spin.deserialize(data["spin"])
return cls(
backbone_model=backbone_model_obj,
Expand Down
16 changes: 16 additions & 0 deletions deepmd/pt/model/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,37 @@
from .base_atomic_model import (
BaseAtomicModel,
)
from .dipole_atomic_model import (
DPDipoleAtomicModel,
)
from .dos_atomic_model import (
DPDOSAtomicModel,
)
from .dp_atomic_model import (
DPAtomicModel,
)
from .energy_atomic_model import (
DPEnergyAtomicModel,
)
from .linear_atomic_model import (
DPZBLLinearEnergyAtomicModel,
LinearEnergyAtomicModel,
)
from .pairtab_atomic_model import (
PairTabAtomicModel,
)
from .polar_atomic_model import (
DPPolarAtomicModel,
)

__all__ = [
"BaseAtomicModel",
"DPAtomicModel",
"DPDOSAtomicModel",
"DPEnergyAtomicModel",
"PairTabAtomicModel",
"LinearEnergyAtomicModel",
"DPPolarAtomicModel",
"DPDipoleAtomicModel",
"DPZBLLinearEnergyAtomicModel",
]
28 changes: 28 additions & 0 deletions deepmd/pt/model/atomic_model/dipole_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
)

import torch

from deepmd.pt.model.task.dipole import (
DipoleFittingNet,
)

from .dp_atomic_model import (
DPAtomicModel,
)


class DPDipoleAtomicModel(DPAtomicModel):
anyangml marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, DipoleFittingNet)
super().__init__(descriptor, fitting, type_map, **kwargs)

def apply_out_stat(
self,
ret: Dict[str, torch.Tensor],
atype: torch.Tensor,
):
# dipole not applying bias
return ret
14 changes: 14 additions & 0 deletions deepmd/pt/model/atomic_model/dos_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.pt.model.task.dos import (
DOSFittingNet,
)

from .dp_atomic_model import (
DPAtomicModel,
)


class DPDOSAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, DOSFittingNet)
super().__init__(descriptor, fitting, type_map, **kwargs)
20 changes: 20 additions & 0 deletions deepmd/pt/model/atomic_model/energy_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.pt.model.task.ener import (
EnergyFittingNet,
EnergyFittingNetDirect,
InvarFitting,
)

from .dp_atomic_model import (
DPAtomicModel,
)


class DPEnergyAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert (
isinstance(fitting, EnergyFittingNet)
or isinstance(fitting, EnergyFittingNetDirect)
or isinstance(fitting, InvarFitting)
)
super().__init__(descriptor, fitting, type_map, **kwargs)
28 changes: 28 additions & 0 deletions deepmd/pt/model/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
)

import torch

from deepmd.pt.model.task.polarizability import (
PolarFittingNet,
)

from .dp_atomic_model import (
DPAtomicModel,
)


class DPPolarAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, PolarFittingNet)
super().__init__(descriptor, fitting, type_map, **kwargs)

def apply_out_stat(
self,
ret: Dict[str, torch.Tensor],
atype: torch.Tensor,
):
# TODO: migrate bias
return ret
26 changes: 23 additions & 3 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@
Spin,
)

from .dipole_model import (
DipoleModel,
)
from .dos_model import (
DOSModel,
)
from .dp_model import (
DPModel,
DPModelCommon,
)
from .dp_zbl_model import (
DPZBLModel,
Expand All @@ -51,6 +57,9 @@
from .model import (
BaseModel,
)
from .polar_model import (
PolarModel,
)
from .spin_model import (
SpinEnergyModel,
SpinModel,
Expand Down Expand Up @@ -160,7 +169,18 @@
atom_exclude_types = model_params.get("atom_exclude_types", [])
pair_exclude_types = model_params.get("pair_exclude_types", [])

model = DPModel(
if fitting_net["type"] == "dipole":
modelcls = DipoleModel
elif fitting_net["type"] == "polar":
modelcls = PolarModel
elif fitting_net["type"] == "dos":
modelcls = DOSModel
elif fitting_net["type"] in ["ener", "direct_force_ener"]:
modelcls = EnergyModel
else:
raise RuntimeError(f"Unknown fitting type: {fitting_net['type']}")

Check warning on line 181 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L181

Added line #L181 was not covered by tests

model = modelcls(
Fixed Show fixed Hide fixed
descriptor=descriptor,
fitting=fitting,
type_map=model_params["type_map"],
Expand All @@ -183,7 +203,7 @@
__all__ = [
"BaseModel",
"get_model",
"DPModel",
"DPModelCommon",
"EnergyModel",
"FrozenModel",
"SpinModel",
Expand Down
Loading