Skip to content

Commit

Permalink
Feat: Add polar stat constant matrix calculation to PT (#3426)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
anyangml and pre-commit-ci[bot] authored Mar 11, 2024
1 parent a286bd4 commit 2ee8a3b
Show file tree
Hide file tree
Showing 13 changed files with 272 additions and 18 deletions.
10 changes: 10 additions & 0 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Dict,
Expand All @@ -19,6 +20,9 @@
OutputVariableDef,
fitting_check_output,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .general_fitting import (
GeneralFitting,
Expand Down Expand Up @@ -153,6 +157,12 @@ def serialize(self) -> dict:
data["c_differentiable"] = self.c_differentiable
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
return super().deserialize(data)

def output_def(self):
return FittingOutputDef(
[
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/fitting/ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from deepmd.dpmodel.fitting.general_fitting import (
GeneralFitting,
)
from deepmd.utils.version import (
check_version_compatibility,
)


@InvarFitting.register("ener")
Expand Down Expand Up @@ -69,6 +72,7 @@ def __init__(
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("var_name")
data.pop("dim_out")
return super().deserialize(data)
Expand Down
4 changes: 0 additions & 4 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
FittingNet,
NetworkCollection,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .base_fitting import (
BaseFitting,
Expand Down Expand Up @@ -256,7 +253,6 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
variables = data.pop("@variables")
Expand Down
10 changes: 10 additions & 0 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Dict,
Expand All @@ -16,6 +17,9 @@
OutputVariableDef,
fitting_check_output,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .general_fitting import (
GeneralFitting,
Expand Down Expand Up @@ -169,6 +173,12 @@ def serialize(self) -> dict:
data["atom_ener"] = self.atom_ener
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
return super().deserialize(data)

def _net_out_dim(self):
"""Set the FittingNet output dim."""
return self.dim_out
Expand Down
35 changes: 35 additions & 0 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Any,
Dict,
Expand All @@ -22,6 +23,9 @@
OutputVariableDef,
fitting_check_output,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .general_fitting import (
GeneralFitting,
Expand Down Expand Up @@ -139,6 +143,7 @@ def __init__(
ntypes, 1
)
self.shift_diag = shift_diag
self.constant_matrix = np.zeros(ntypes, dtype=GLOBAL_NP_FLOAT_PRECISION)
super().__init__(
var_name=var_name,
ntypes=ntypes,
Expand Down Expand Up @@ -168,15 +173,36 @@ def _net_out_dim(self):
else self.embedding_width * self.embedding_width
)

def __setitem__(self, key, value):
if key in ["constant_matrix"]:
self.constant_matrix = value
else:
super().__setitem__(key, value)

def __getitem__(self, key):
if key in ["constant_matrix"]:
return self.constant_matrix
else:
return super().__getitem__(key)

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "polar"
data["@version"] = 2
data["embedding_width"] = self.embedding_width
data["old_impl"] = self.old_impl
data["fit_diag"] = self.fit_diag
data["shift_diag"] = self.shift_diag
data["@variables"]["scale"] = self.scale
data["@variables"]["constant_matrix"] = self.constant_matrix
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
return super().deserialize(data)

def output_def(self):
return FittingOutputDef(
[
Expand Down Expand Up @@ -246,4 +272,13 @@ def call(
"bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out
) # (nframes * nloc, 3, 3)
out = out.reshape(nframes, nloc, 3, 3)
if self.shift_diag:
bias = self.constant_matrix[atype]
# (nframes, nloc, 1)
bias = np.expand_dims(bias, axis=-1) * self.scale[atype]
eye = np.eye(3)
eye = np.tile(eye, (nframes, nloc, 1, 1))
# (nframes, nloc, 3, 3)
bias = np.expand_dims(bias, axis=-1) * eye
out = out + bias
return {self.var_name: out}
10 changes: 10 additions & 0 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
from typing import (
Callable,
Expand All @@ -25,6 +26,9 @@
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -123,6 +127,12 @@ def serialize(self) -> dict:
data["c_differentiable"] = self.c_differentiable
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
return super().deserialize(data)

def output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
Expand Down
10 changes: 10 additions & 0 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE
Expand Down Expand Up @@ -140,6 +143,12 @@ def serialize(self) -> dict:
data["atom_ener"] = self.atom_ener
return data

@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
return super().deserialize(data)

def compute_output_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
Expand Down Expand Up @@ -241,6 +250,7 @@ def __init__(
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("var_name")
data.pop("dim_out")
return super().deserialize(data)
Expand Down
4 changes: 0 additions & 4 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@
from deepmd.utils.finetune import (
change_energy_bias_lower,
)
from deepmd.utils.version import (
check_version_compatibility,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE
Expand Down Expand Up @@ -371,7 +368,6 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
variables = data.pop("@variables")
nets = data.pop("nets")
obj = cls(**data)
Expand Down
Loading

0 comments on commit 2ee8a3b

Please sign in to comment.