Skip to content

Commit

Permalink
feat(jax/array-api): dipole/polarizability fitting (#4278)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Release Notes

- **New Features**
- Introduced `DipoleFittingNet` and `PolarFittingNet` classes for
enhanced fitting functionality.
- Expanded support for JAX as a backend for fitting tensors, alongside
existing TensorFlow and PyTorch support.

- **Bug Fixes**
- Improved error handling and parameter validation in the
`DipoleFitting` and `PolarFitting` classes.

- **Documentation**
- Updated documentation to reflect JAX as a supported backend for
fitting tensors.

- **Tests**
- Enhanced testing framework to support evaluations with JAX and Array
API Strict, including new test methods and properties.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Oct 31, 2024
1 parent 737f7c8 commit 8e27d2f
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 35 deletions.
10 changes: 7 additions & 3 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
Expand Down Expand Up @@ -207,16 +208,19 @@ def call(
The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam`
"""
xp = array_api_compat.array_namespace(descriptor, atype)
nframes, nloc, _ = descriptor.shape
assert gr is not None, "Must provide the rotation matrix for dipole fitting."
# (nframes, nloc, m1)
out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
# (nframes * nloc, 1, m1)
out = out.reshape(-1, 1, self.embedding_width)
out = xp.reshape(out, (-1, 1, self.embedding_width))
# (nframes * nloc, m1, 3)
gr = gr.reshape(nframes * nloc, -1, 3)
gr = xp.reshape(gr, (nframes * nloc, -1, 3))
# (nframes, nloc, 3)
out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3)
# out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3)
out = out @ gr
out = xp.reshape(out, (nframes, nloc, 3))
return {self.var_name: out}
71 changes: 41 additions & 30 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.common import (
Expand All @@ -14,6 +15,9 @@
from deepmd.dpmodel import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.fitting.base_fitting import (
BaseFitting,
)
Expand Down Expand Up @@ -124,23 +128,18 @@ def __init__(

self.embedding_width = embedding_width
self.fit_diag = fit_diag
self.scale = scale
if self.scale is None:
self.scale = [1.0 for _ in range(ntypes)]
if scale is None:
scale = [1.0 for _ in range(ntypes)]
else:
if isinstance(self.scale, list):
assert (
len(self.scale) == ntypes
), "Scale should be a list of length ntypes."
elif isinstance(self.scale, float):
self.scale = [self.scale for _ in range(ntypes)]
if isinstance(scale, list):
assert len(scale) == ntypes, "Scale should be a list of length ntypes."
elif isinstance(scale, float):
scale = [scale for _ in range(ntypes)]
else:
raise ValueError(
"Scale must be a list of float of length ntypes or a float."
)
self.scale = np.array(self.scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape(
ntypes, 1
)
self.scale = np.array(scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape(ntypes, 1)
self.shift_diag = shift_diag
self.constant_matrix = np.zeros(ntypes, dtype=GLOBAL_NP_FLOAT_PRECISION)
super().__init__(
Expand Down Expand Up @@ -192,8 +191,8 @@ def serialize(self) -> dict:
data["embedding_width"] = self.embedding_width
data["fit_diag"] = self.fit_diag
data["shift_diag"] = self.shift_diag
data["@variables"]["scale"] = self.scale
data["@variables"]["constant_matrix"] = self.constant_matrix
data["@variables"]["scale"] = to_numpy_array(self.scale)
data["@variables"]["constant_matrix"] = to_numpy_array(self.constant_matrix)
return data

@classmethod
Expand Down Expand Up @@ -276,6 +275,7 @@ def call(
The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam`
"""
xp = array_api_compat.array_namespace(descriptor, atype)
nframes, nloc, _ = descriptor.shape
assert (
gr is not None
Expand All @@ -284,28 +284,39 @@ def call(
out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
out = out * self.scale[atype]
# out = out * self.scale[atype, ...]
scale_atype = xp.reshape(
xp.take(self.scale, xp.reshape(atype, [-1]), axis=0), (*atype.shape, 1)
)
out = out * scale_atype
# (nframes * nloc, m1, 3)
gr = gr.reshape(nframes * nloc, -1, 3)
gr = xp.reshape(gr, (nframes * nloc, -1, 3))

if self.fit_diag:
out = out.reshape(-1, self.embedding_width)
out = np.einsum("ij,ijk->ijk", out, gr)
out = xp.reshape(out, (-1, self.embedding_width))
# out = np.einsum("ij,ijk->ijk", out, gr)
out = out[:, :, None] * gr
else:
out = out.reshape(-1, self.embedding_width, self.embedding_width)
out = (out + np.transpose(out, axes=(0, 2, 1))) / 2
out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3)
out = np.einsum(
"bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out
) # (nframes * nloc, 3, 3)
out = out.reshape(nframes, nloc, 3, 3)
out = xp.reshape(out, (-1, self.embedding_width, self.embedding_width))
out = (out + xp.matrix_transpose(out)) / 2
# out = np.einsum("bim,bmj->bij", out, gr) # (nframes * nloc, m1, 3)
out = out @ gr
# out = np.einsum(
# "bim,bmj->bij", np.transpose(gr, axes=(0, 2, 1)), out
# ) # (nframes * nloc, 3, 3)
out = xp.matrix_transpose(gr) @ out
out = xp.reshape(out, (nframes, nloc, 3, 3))
if self.shift_diag:
bias = self.constant_matrix[atype]
# bias = self.constant_matrix[atype]
bias = xp.reshape(
xp.take(self.constant_matrix, xp.reshape(atype, [-1]), axis=0),
(nframes, nloc),
)
# (nframes, nloc, 1)
bias = np.expand_dims(bias, axis=-1) * self.scale[atype]
eye = np.eye(3, dtype=descriptor.dtype)
eye = np.tile(eye, (nframes, nloc, 1, 1))
bias = bias[..., None] * scale_atype
eye = xp.eye(3, dtype=descriptor.dtype)
eye = xp.tile(eye, (nframes, nloc, 1, 1))
# (nframes, nloc, 3, 3)
bias = np.expand_dims(bias, axis=-1) * eye
bias = bias[..., None] * eye
out = out + bias
return {"polarizability": out}
4 changes: 4 additions & 0 deletions deepmd/jax/fitting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.jax.fitting.fitting import (
DipoleFittingNet,
DOSFittingNet,
EnergyFittingNet,
PolarFittingNet,
)

__all__ = [
"EnergyFittingNet",
"DOSFittingNet",
"DipoleFittingNet",
"PolarFittingNet",
]
27 changes: 27 additions & 0 deletions deepmd/jax/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
Any,
)

from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingNetDP
from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
from deepmd.dpmodel.fitting.polarizability_fitting import (
PolarFitting as PolarFittingNetDP,
)
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
Expand Down Expand Up @@ -53,3 +57,26 @@ class DOSFittingNet(DOSFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)


@BaseFitting.register("dipole")
@flax_module
class DipoleFittingNet(DipoleFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)


@BaseFitting.register("polar")
@flax_module
class PolarFittingNet(PolarFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
if name in {
"scale",
"constant_matrix",
}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
return super().__setattr__(name, value)
4 changes: 2 additions & 2 deletions doc/model/train-fitting-tensor.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Fit `tensor` like `Dipole` and `Polarizability` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Fit `tensor` like `Dipole` and `Polarizability` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

Unlike `energy`, which is a scalar, one may want to fit some high dimensional physical quantity, like `dipole` (vector) and `polarizability` (matrix, shorted as `polar`). Deep Potential has provided different APIs to do this. In this example, we will show you how to train a model to fit a water system. A complete training input script of the examples can be found in
Expand Down
21 changes: 21 additions & 0 deletions source/tests/array_api_strict/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
Any,
)

from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingNetDP
from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
from deepmd.dpmodel.fitting.polarizability_fitting import (
PolarFitting as PolarFittingNetDP,
)

from ..common import (
to_array_api_strict_array,
Expand Down Expand Up @@ -43,3 +47,20 @@ class DOSFittingNet(DOSFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)


class DipoleFittingNet(DipoleFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)


class PolarFittingNet(PolarFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
if name in {
"scale",
"constant_matrix",
}:
value = to_array_api_strict_array(value)
return super().__setattr__(name, value)
41 changes: 41 additions & 0 deletions source/tests/consistent/fitting/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -32,6 +34,21 @@
from deepmd.tf.fit.dipole import DipoleFittingSeA as DipoleFittingTF
else:
DipoleFittingTF = object
if INSTALLED_JAX:
from deepmd.jax.env import (
jnp,
)
from deepmd.jax.fitting.fitting import DipoleFittingNet as DipoleFittingJAX
else:
DipoleFittingJAX = object
if INSTALLED_ARRAY_API_STRICT:
import array_api_strict

from ...array_api_strict.fitting.fitting import (
DipoleFittingNet as DipoleFittingArrayAPIStrict,
)
else:
DipoleFittingArrayAPIStrict = object
from deepmd.utils.argcheck import (
fitting_dipole,
)
Expand Down Expand Up @@ -69,7 +86,11 @@ def skip_pt(self) -> bool:
tf_class = DipoleFittingTF
dp_class = DipoleFittingDP
pt_class = DipoleFittingPT
jax_class = DipoleFittingJAX
array_api_strict_class = DipoleFittingArrayAPIStrict
args = fitting_dipole()
skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

def setUp(self):
CommonTest.setUp(self)
Expand Down Expand Up @@ -143,6 +164,26 @@ def eval_dp(self, dp_obj: Any) -> Any:
None,
)["dipole"]

def eval_jax(self, jax_obj: Any) -> Any:
return np.asarray(
jax_obj(
jnp.asarray(self.inputs),
jnp.asarray(self.atype.reshape(1, -1)),
jnp.asarray(self.gr),
None,
)["dipole"]
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return np.asarray(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
array_api_strict.asarray(self.gr),
None,
)["dipole"]
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
if backend == self.RefBackend.TF:
# shape is not same
Expand Down
41 changes: 41 additions & 0 deletions source/tests/consistent/fitting/test_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -32,6 +34,21 @@
from deepmd.tf.fit.polar import PolarFittingSeA as PolarFittingTF
else:
PolarFittingTF = object
if INSTALLED_JAX:
from deepmd.jax.env import (
jnp,
)
from deepmd.jax.fitting.fitting import PolarFittingNet as PolarFittingJAX
else:
PolarFittingJAX = object
if INSTALLED_ARRAY_API_STRICT:
import array_api_strict

from ...array_api_strict.fitting.fitting import (
PolarFittingNet as PolarFittingArrayAPIStrict,
)
else:
PolarFittingArrayAPIStrict = object
from deepmd.utils.argcheck import (
fitting_polar,
)
Expand Down Expand Up @@ -69,7 +86,11 @@ def skip_pt(self) -> bool:
tf_class = PolarFittingTF
dp_class = PolarFittingDP
pt_class = PolarFittingPT
jax_class = PolarFittingJAX
array_api_strict_class = PolarFittingArrayAPIStrict
args = fitting_polar()
skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

def setUp(self):
CommonTest.setUp(self)
Expand Down Expand Up @@ -143,6 +164,26 @@ def eval_dp(self, dp_obj: Any) -> Any:
None,
)["polarizability"]

def eval_jax(self, jax_obj: Any) -> Any:
return np.asarray(
jax_obj(
jnp.asarray(self.inputs),
jnp.asarray(self.atype.reshape(1, -1)),
jnp.asarray(self.gr),
None,
)["polarizability"]
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return np.asarray(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
array_api_strict.asarray(self.gr),
None,
)["polarizability"]
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
if backend == self.RefBackend.TF:
# shape is not same
Expand Down

0 comments on commit 8e27d2f

Please sign in to comment.