Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore: refactor dpmodel #3663

Merged
merged 44 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
02e546c
chore: try remove dp model
anyangml Apr 11, 2024
4c026fb
chore: try remove dp model
anyangml Apr 11, 2024
81573d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
3ad9637
fix: import
anyangml Apr 11, 2024
c64e520
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
acd1b99
fix: import
anyangml Apr 11, 2024
db2751c
fix: UTs
anyangml Apr 11, 2024
9612b02
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
a7384d2
fix: UTs
anyangml Apr 11, 2024
30bd378
fix: UTs
anyangml Apr 11, 2024
6c9185e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
24859ac
fix: UTs
anyangml Apr 11, 2024
754e18a
fix: UTs
anyangml Apr 11, 2024
56accdf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
9f437ed
fix: address comments
anyangml Apr 12, 2024
ed34307
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2024
72a34fc
fix: import
anyangml Apr 12, 2024
5355428
fix: UTs
anyangml Apr 12, 2024
257936d
feat: try remove standard
anyangml Apr 13, 2024
a51a792
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
4b14d54
fix: precommit
anyangml Apr 13, 2024
5484375
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
3c5c795
fix: import
anyangml Apr 13, 2024
224ef27
fix: UTs
anyangml Apr 13, 2024
bd079b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
53dbe13
fix: UTs
anyangml Apr 13, 2024
78ff90d
fix: UTs
anyangml Apr 13, 2024
7f99512
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
23fa34d
fix: try modify argcheck
anyangml Apr 13, 2024
16583de
fix: try modify argcheck
anyangml Apr 13, 2024
23fd899
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
4fe3661
fix: plugin
anyangml Apr 13, 2024
576e75d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
d34a7d4
Merge branch 'devel' into chore/refactor-dpmodel
anyangml Apr 14, 2024
3173978
fix:inheritance
anyangml Apr 14, 2024
1e54270
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2024
353ca6d
fix: revert changes
anyangml Apr 14, 2024
8de02ee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2024
e2a80b3
fix: UTs
anyangml Apr 14, 2024
efcf951
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2024
0c100fa
fix:UTs
anyangml Apr 14, 2024
752132d
fix: UTs
anyangml Apr 14, 2024
ff30173
fix: tf register
anyangml Apr 15, 2024
7399d3e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions deepmd/pt/model/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from .base_atomic_model import (
BaseAtomicModel,
)
from .dipole_atomic_model import (

Check warning on line 20 in deepmd/pt/model/atomic_model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/__init__.py#L20

Added line #L20 was not covered by tests
DPDipoleAtomicModel,
)
from .dp_atomic_model import (
DPAtomicModel,
)
Expand All @@ -27,11 +30,16 @@
from .pairtab_atomic_model import (
PairTabAtomicModel,
)
from .polar_atomic_model import (

Check warning on line 33 in deepmd/pt/model/atomic_model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/__init__.py#L33

Added line #L33 was not covered by tests
DPPolarAtomicModel,
)

__all__ = [
"BaseAtomicModel",
"DPAtomicModel",
"PairTabAtomicModel",
"LinearEnergyAtomicModel",
"DPPolarAtomicModel",
"DPDipoleAtomicModel",
"DPZBLLinearEnergyAtomicModel",
]
10 changes: 10 additions & 0 deletions deepmd/pt/model/atomic_model/dipole_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .dp_atomic_model import (

Check warning on line 2 in deepmd/pt/model/atomic_model/dipole_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dipole_atomic_model.py#L2

Added line #L2 was not covered by tests
DPAtomicModel,
)


class DPDipoleAtomicModel(DPAtomicModel):
anyangml marked this conversation as resolved.
Show resolved Hide resolved
def apply_out_stat(self, ret, atype):

Check warning on line 8 in deepmd/pt/model/atomic_model/dipole_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dipole_atomic_model.py#L7-L8

Added lines #L7 - L8 were not covered by tests
# dipole not applying bias
pass

Check warning on line 10 in deepmd/pt/model/atomic_model/dipole_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dipole_atomic_model.py#L10

Added line #L10 was not covered by tests
10 changes: 10 additions & 0 deletions deepmd/pt/model/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .dp_atomic_model import (

Check warning on line 2 in deepmd/pt/model/atomic_model/polar_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/polar_atomic_model.py#L2

Added line #L2 was not covered by tests
DPAtomicModel,
)


class DPPolarAtomicModel(DPAtomicModel):
def apply_out_stat(self, ret, atype):

Check warning on line 8 in deepmd/pt/model/atomic_model/polar_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/polar_atomic_model.py#L7-L8

Added lines #L7 - L8 were not covered by tests
# TODO: migrate bias
pass

Check warning on line 10 in deepmd/pt/model/atomic_model/polar_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/polar_atomic_model.py#L10

Added line #L10 was not covered by tests
22 changes: 21 additions & 1 deletion deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
Spin,
)

from .dipole_model import (

Check warning on line 33 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L33

Added line #L33 was not covered by tests
DipoleModel,
)
from .dos_model import (

Check warning on line 36 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L36

Added line #L36 was not covered by tests
DOSModel,
)
from .dp_model import (
DPModel,
)
Expand All @@ -51,6 +57,9 @@
from .model import (
BaseModel,
)
from .polar_model import (

Check warning on line 60 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L60

Added line #L60 was not covered by tests
PolarModel,
)
from .spin_model import (
SpinEnergyModel,
SpinModel,
Expand Down Expand Up @@ -160,7 +169,18 @@
atom_exclude_types = model_params.get("atom_exclude_types", [])
pair_exclude_types = model_params.get("pair_exclude_types", [])

model = DPModel(
if fitting_net["type"] == "dipole":
modelcls = DipoleModel
elif fitting_net["type"] == "polar":
modelcls = PolarModel
elif fitting_net["type"] == "dos":
modelcls = DOSModel
elif fitting_net["type"] == "ener":
modelcls = EnergyModel

Check warning on line 179 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L172-L179

Added lines #L172 - L179 were not covered by tests
else:
pass

Check warning on line 181 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L181

Added line #L181 was not covered by tests

model = modelcls(

Check warning on line 183 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L183

Added line #L183 was not covered by tests
Fixed Show fixed Hide fixed
descriptor=descriptor,
fitting=fitting,
type_map=model_params["type_map"],
Expand Down
21 changes: 20 additions & 1 deletion deepmd/pt/model/model/dipole_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,23 @@

import torch

from deepmd.pt.model.atomic_model import (

Check warning on line 9 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L9

Added line #L9 was not covered by tests
DPDipoleAtomicModel,
)
from deepmd.pt.model.model.model import (

Check warning on line 12 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L12

Added line #L12 was not covered by tests
BaseModel,
)

from .dp_model import (
DPModel,
)
from .make_model import (

Check warning on line 19 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L19

Added line #L19 was not covered by tests
make_model,
)


class DipoleModel(DPModel):
@BaseModel.register("standard")
class DipoleModel(DPModel, make_model(DPDipoleAtomicModel)):

Check warning on line 25 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L24-L25

Added lines #L24 - L25 were not covered by tests
model_type = "dipole"

def __init__(
Expand Down Expand Up @@ -57,6 +68,14 @@
model_predict["updated_coord"] += coord
return model_predict

def get_fitting_net(self):

Check warning on line 71 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L71

Added line #L71 was not covered by tests
"""Get the fitting network."""
return self.atomic_model.fitting_net

Check warning on line 73 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L73

Added line #L73 was not covered by tests

def get_descriptor(self):

Check warning on line 75 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L75

Added line #L75 was not covered by tests
"""Get the descriptor."""
return self.atomic_model.descriptor

Check warning on line 77 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L77

Added line #L77 was not covered by tests

@torch.jit.export
def forward_lower(
self,
Expand Down
21 changes: 20 additions & 1 deletion deepmd/pt/model/model/dos_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,23 @@

import torch

from deepmd.pt.model.atomic_model import (

Check warning on line 9 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L9

Added line #L9 was not covered by tests
DPAtomicModel,
)
from deepmd.pt.model.model.model import (

Check warning on line 12 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L12

Added line #L12 was not covered by tests
BaseModel,
)

from .dp_model import (
DPModel,
)
from .make_model import (

Check warning on line 19 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L19

Added line #L19 was not covered by tests
make_model,
)


class DOSModel(DPModel):
@BaseModel.register("standard")
class DOSModel(DPModel, make_model(DPAtomicModel)):

Check warning on line 25 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L24-L25

Added lines #L24 - L25 were not covered by tests
model_type = "dos"

def __init__(
Expand Down Expand Up @@ -50,6 +61,14 @@
model_predict["updated_coord"] += coord
return model_predict

def get_fitting_net(self):

Check warning on line 64 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L64

Added line #L64 was not covered by tests
"""Get the fitting network."""
return self.atomic_model.fitting_net

Check warning on line 66 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L66

Added line #L66 was not covered by tests

def get_descriptor(self):

Check warning on line 68 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L68

Added line #L68 was not covered by tests
"""Get the descriptor."""
return self.atomic_model.descriptor

Check warning on line 70 in deepmd/pt/model/model/dos_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dos_model.py#L70

Added line #L70 was not covered by tests

@torch.jit.export
def get_numb_dos(self) -> int:
"""Get the number of DOS for DOSFittingNet."""
Expand Down
103 changes: 2 additions & 101 deletions deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
Optional,
)

import torch

from deepmd.pt.model.atomic_model import (
DPAtomicModel,
)
from deepmd.pt.model.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.pt.model.model.model import (
BaseModel,
)
from deepmd.pt.model.task.dipole import (
DipoleFittingNet,
)
from deepmd.pt.model.task.dos import (
DOSFittingNet,
)
from deepmd.pt.model.task.ener import (
EnergyFittingNet,
EnergyFittingNetDirect,
)
from deepmd.pt.model.task.polarizability import (
PolarFittingNet,
)

from .make_model import (
make_model,
)


@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel)):
def __new__(
cls,
descriptor=None,
fitting=None,
*args,
# disallow positional atomic_model_
atomic_model_: Optional[DPAtomicModel] = None,
**kwargs,
):
from deepmd.pt.model.model.dipole_model import (
DipoleModel,
)
from deepmd.pt.model.model.dos_model import (
DOSModel,
)
from deepmd.pt.model.model.ener_model import (
EnergyModel,
)
from deepmd.pt.model.model.polar_model import (
PolarModel,
)

if atomic_model_ is not None:
fitting = atomic_model_.fitting_net
else:
assert fitting is not None, "fitting network is not provided"

# according to the fitting network to decide the type of the model
if cls is DPModel:
# map fitting to model
if isinstance(fitting, EnergyFittingNet) or isinstance(
fitting, EnergyFittingNetDirect
):
cls = EnergyModel
elif isinstance(fitting, DipoleFittingNet):
cls = DipoleModel
elif isinstance(fitting, PolarFittingNet):
cls = PolarModel
elif isinstance(fitting, DOSFittingNet):
cls = DOSModel
# else: unknown fitting type, fall back to DPModel
return super().__new__(cls)
class DPModel:

Check warning on line 7 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L7

Added line #L7 was not covered by tests
anyangml marked this conversation as resolved.
Show resolved Hide resolved
"""A base class to implement common methods for all the Models."""

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
Expand All @@ -95,30 +23,3 @@
global_jdata, local_jdata["descriptor"]
)
return local_jdata_cpy

def get_fitting_net(self):
"""Get the fitting network."""
return self.atomic_model.fitting_net

def get_descriptor(self):
"""Get the descriptor."""
return self.atomic_model.descriptor
anyangml marked this conversation as resolved.
Show resolved Hide resolved

def forward(
self,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
# directly call the forward_common method when no specific transform rule
return self.forward_common(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
25 changes: 4 additions & 21 deletions deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@

import torch

from deepmd.dpmodel.model.dp_model import (
DPModel,
)
from deepmd.pt.model.atomic_model import (
DPZBLLinearEnergyAtomicModel,
)
from deepmd.pt.model.model.model import (
BaseModel,
)

from .dp_model import (

Check warning on line 16 in deepmd/pt/model/model/dp_zbl_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_zbl_model.py#L16

Added line #L16 was not covered by tests
DPModel,
)
from .make_model import (
make_model,
)
Expand All @@ -24,7 +24,7 @@


@BaseModel.register("zbl")
class DPZBLModel(DPZBLModel_):
class DPZBLModel(DPModel, DPZBLModel_):

Check warning on line 27 in deepmd/pt/model/model/dp_zbl_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_zbl_model.py#L27

Added line #L27 was not covered by tests
model_type = "ener"

def __init__(
Expand Down Expand Up @@ -103,20 +103,3 @@
assert model_ret["dforce"] is not None
model_predict["dforce"] = model_ret["dforce"]
return model_predict

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["dpmodel"] = DPModel.update_sel(
global_jdata, local_jdata["dpmodel"]
)
return local_jdata_cpy
21 changes: 20 additions & 1 deletion deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,23 @@

import torch

from deepmd.pt.model.atomic_model import (

Check warning on line 9 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L9

Added line #L9 was not covered by tests
DPAtomicModel,
)
from deepmd.pt.model.model.model import (

Check warning on line 12 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L12

Added line #L12 was not covered by tests
BaseModel,
)

from .dp_model import (
DPModel,
)
from .make_model import (

Check warning on line 19 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L19

Added line #L19 was not covered by tests
make_model,
)


class EnergyModel(DPModel):
@BaseModel.register("standard")
anyangml marked this conversation as resolved.
Show resolved Hide resolved
class EnergyModel(DPModel, make_model(DPAtomicModel)):

Check warning on line 25 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L24-L25

Added lines #L24 - L25 were not covered by tests
model_type = "ener"

def __init__(
Expand Down Expand Up @@ -59,6 +70,14 @@
model_predict["updated_coord"] += coord
return model_predict

def get_fitting_net(self):

Check warning on line 73 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L73

Added line #L73 was not covered by tests
"""Get the fitting network."""
return self.atomic_model.fitting_net

Check warning on line 75 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L75

Added line #L75 was not covered by tests

def get_descriptor(self):

Check warning on line 77 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L77

Added line #L77 was not covered by tests
"""Get the descriptor."""
return self.atomic_model.descriptor

Check warning on line 79 in deepmd/pt/model/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/ener_model.py#L79

Added line #L79 was not covered by tests

@torch.jit.export
def forward_lower(
self,
Expand Down
Loading
Loading