feat(jax/array-api): energy fitting (#4204)
## Summary by CodeRabbit

- **New Features**
  - Introduced a fitting module for energy models using JAX, enhancing compatibility with different array backends.
  - Added `AtomExcludeMask` class for improved attribute handling in exclusion masks.

- **Improvements**
  - Updated serialization and array handling methods for better integration with array APIs.
  - Enhanced testing capabilities for energy fitting with support for different backends.

- **Documentation**
  - Added SPDX license identifier to relevant files for licensing clarity.

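The common thread in this commit is the Python array API standard: rather than calling `np.*` directly, the fitting code resolves a namespace from its input arrays via `array_api_compat.array_namespace` and then sticks to standard functions (`xp.concat`, `xp.reshape`, `xp.take`, `xp.astype`), so the same code path serves NumPy, JAX, and `array_api_strict`. A minimal sketch of the pattern (the function and variable names here are illustrative, not from the repository):

```python
import array_api_compat
import numpy as np


def concat_columns(x, y):
    # Resolve the array-API namespace shared by the inputs; the same body
    # then works for numpy, jax.numpy, or array_api_strict arrays.
    xp = array_api_compat.array_namespace(x, y)
    x = xp.reshape(x, (-1, 1))
    y = xp.reshape(y, (-1, 1))
    # The standard spelling is xp.concat, not np.concatenate.
    return xp.concat([x, y], axis=-1)


print(concat_columns(np.arange(3.0), np.ones(3)))  # shape (3, 2)
```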

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz authored Oct 13, 2024
1 parent c10bc3c commit 8279cca
Showing 9 changed files with 197 additions and 22 deletions.
48 changes: 29 additions & 19 deletions deepmd/dpmodel/fitting/general_fitting.py
@@ -9,12 +9,16 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
DEFAULT_PRECISION,
NativeOP,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils import (
AtomExcludeMask,
FittingNet,
@@ -283,11 +287,11 @@ def serialize(self) -> dict:
"exclude_types": self.exclude_types,
"nets": self.nets.serialize(),
"@variables": {
"bias_atom_e": self.bias_atom_e,
"fparam_avg": self.fparam_avg,
"fparam_inv_std": self.fparam_inv_std,
"aparam_avg": self.aparam_avg,
"aparam_inv_std": self.aparam_inv_std,
"bias_atom_e": to_numpy_array(self.bias_atom_e),
"fparam_avg": to_numpy_array(self.fparam_avg),
"fparam_inv_std": to_numpy_array(self.fparam_inv_std),
"aparam_avg": to_numpy_array(self.aparam_avg),
"aparam_inv_std": to_numpy_array(self.aparam_inv_std),
},
"type_map": self.type_map,
# not supported
@@ -344,6 +348,7 @@ def _call_common(
The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam`
"""
xp = array_api_compat.array_namespace(descriptor, atype)
nf, nloc, nd = descriptor.shape
net_dim_out = self._net_out_dim()
# check input dim
@@ -359,7 +364,7 @@
# we consider it as always zero for convenience.
# Needs a compute_input_stats for vacuum passed from the
# descriptor.
xx_zeros = np.zeros_like(xx)
xx_zeros = xp.zeros_like(xx)
else:
xx_zeros = None
# check fparam dim, concatenate to input descriptor
@@ -371,13 +376,15 @@
"which is not consistent with {self.numb_fparam}.",
)
fparam = (fparam - self.fparam_avg) * self.fparam_inv_std
fparam = np.tile(fparam.reshape([nf, 1, self.numb_fparam]), [1, nloc, 1])
xx = np.concatenate(
fparam = xp.tile(
xp.reshape(fparam, [nf, 1, self.numb_fparam]), (1, nloc, 1)
)
xx = xp.concat(
[xx, fparam],
axis=-1,
)
if xx_zeros is not None:
xx_zeros = np.concatenate(
xx_zeros = xp.concat(
[xx_zeros, fparam],
axis=-1,
)
@@ -389,24 +396,24 @@
"get an input aparam of dim {aparam.shape[-1]}, ",
"which is not consistent with {self.numb_aparam}.",
)
aparam = aparam.reshape([nf, nloc, self.numb_aparam])
aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam])
aparam = (aparam - self.aparam_avg) * self.aparam_inv_std
xx = np.concatenate(
xx = xp.concat(
[xx, aparam],
axis=-1,
)
if xx_zeros is not None:
xx_zeros = np.concatenate(
xx_zeros = xp.concat(
[xx_zeros, aparam],
axis=-1,
)

# calculate the prediction
if not self.mixed_types:
outs = np.zeros([nf, nloc, net_dim_out]) # pylint: disable=no-explicit-dtype
outs = xp.zeros([nf, nloc, net_dim_out]) # pylint: disable=no-explicit-dtype
for type_i in range(self.ntypes):
mask = np.tile(
(atype == type_i).reshape([nf, nloc, 1]), [1, 1, net_dim_out]
mask = xp.tile(
xp.reshape((atype == type_i), [nf, nloc, 1]), (1, 1, net_dim_out)
)
atom_property = self.nets[(type_i,)](xx)
if self.remove_vaccum_contribution is not None and not (
@@ -415,15 +422,18 @@
):
assert xx_zeros is not None
atom_property -= self.nets[(type_i,)](xx_zeros)
atom_property = atom_property + self.bias_atom_e[type_i]
atom_property = atom_property * mask
atom_property = atom_property + self.bias_atom_e[type_i, ...]
atom_property = atom_property * xp.astype(mask, atom_property.dtype)
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
else:
outs = self.nets[()](xx) + self.bias_atom_e[atype]
outs = self.nets[()](xx) + xp.reshape(
xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0),
[nf, nloc, net_dim_out],
)
if xx_zeros is not None:
outs -= self.nets[()](xx_zeros)
# nf x nloc
exclude_mask = self.emask.build_type_exclude_mask(atype)
# nf x nloc x nod
outs = outs * exclude_mask[:, :, None]
outs = outs * xp.astype(exclude_mask[:, :, None], outs.dtype)
return {self.var_name: outs}
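Two of the rewrites above are worth calling out. First, integer-array indexing such as `self.bias_atom_e[atype]` is not guaranteed by the array API standard, and `xp.take` (as of the 2023.12 revision) accepts only one-dimensional indices, hence the flatten-take-reshape dance. A hedged sketch of that gather, assuming `array_api_strict` is installed (shapes here are illustrative):

```python
import array_api_strict as xp

ntypes, net_dim_out, nf, nloc = 3, 1, 2, 4
bias_atom_e = xp.reshape(
    xp.arange(ntypes * net_dim_out, dtype=xp.float64), (ntypes, net_dim_out)
)
atype = xp.asarray([[0, 1, 2, 1], [2, 2, 0, 1]])  # nf x nloc

# Portable equivalent of the numpy fancy index bias_atom_e[atype]:
bias = xp.reshape(
    xp.take(bias_atom_e, xp.reshape(atype, (-1,)), axis=0),
    (nf, nloc, net_dim_out),
)
```

Second, integer masks are no longer multiplied into floating-point outputs implicitly; `xp.astype(mask, outs.dtype)` makes the cast explicit, since the standard leaves mixed-kind promotion unspecified and strict backends reject it.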
8 changes: 5 additions & 3 deletions deepmd/dpmodel/utils/exclude_mask.py
@@ -18,12 +18,12 @@ def __init__(
):
self.ntypes = ntypes
self.exclude_types = exclude_types
self.type_mask = np.array(
type_mask = np.array(
[1 if tt_i not in self.exclude_types else 0 for tt_i in range(ntypes)],
dtype=np.int32,
)
# (ntypes)
self.type_mask = self.type_mask.reshape([-1])
self.type_mask = type_mask.reshape([-1])

def get_exclude_types(self):
return self.exclude_types
@@ -52,7 +52,9 @@ def build_type_exclude_mask(
"""
xp = array_api_compat.array_namespace(atype)
nf, natom = atype.shape
return xp.reshape(self.type_mask[atype], (nf, natom))
return xp.reshape(
xp.take(self.type_mask, xp.reshape(atype, [-1]), axis=0), (nf, natom)
)


class PairExcludeMask:
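For reference, this is how the refactored mask behaves (a small illustration using the import path shown in the diff; the constructor arguments follow the `value.ntypes, value.exclude_types` call used elsewhere in this commit): atom types listed in `exclude_types` get 0 in the per-atom mask, everything else gets 1.

```python
import numpy as np

from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask

emask = AtomExcludeMask(3, [1])              # ntypes=3, exclude type 1
atype = np.array([[0, 1, 2, 1]])             # nf x natom
print(emask.build_type_exclude_mask(atype))  # [[1 0 1 0]]
```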
1 change: 1 addition & 0 deletions deepmd/jax/fitting/__init__.py
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
39 changes: 39 additions & 0 deletions deepmd/jax/fitting/fitting.py
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
from deepmd.jax.common import (
flax_module,
to_jax_array,
)
from deepmd.jax.utils.exclude_mask import (
AtomExcludeMask,
)
from deepmd.jax.utils.network import (
NetworkCollection,
)


def setattr_for_general_fitting(name: str, value: Any) -> Any:
if name in {
"bias_atom_e",
"fparam_avg",
"fparam_inv_std",
"aparam_avg",
"aparam_inv_std",
}:
value = to_jax_array(value)
elif name == "emask":
value = AtomExcludeMask(value.ntypes, value.exclude_types)
elif name == "nets":
value = NetworkCollection.deserialize(value.serialize())
return value


@flax_module
class EnergyFittingNet(EnergyFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)
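The `__setattr__` hook is the whole trick: the class reuses the dpmodel implementation untouched, and any attribute the base class assigns (during `__init__` or deserialization) is converted to a JAX-friendly type on the way in. A hedged usage sketch, assuming the dpmodel class exposes the usual `serialize`/`deserialize` pair and that the constructor arguments below are valid (they are illustrative, not taken from this diff):

```python
from deepmd.dpmodel.fitting.ener_fitting import (
    EnergyFittingNet as EnergyFittingNetDP,
)

from deepmd.jax.fitting.fitting import EnergyFittingNet

# Build a numpy-backed net, then round-trip it into the JAX subclass;
# bias_atom_e etc. come out as jax arrays via setattr_for_general_fitting.
dp_net = EnergyFittingNetDP(ntypes=2, dim_descrpt=8, neuron=[4, 4])
jax_net = EnergyFittingNet.deserialize(dp_net.serialize())
```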
9 changes: 9 additions & 0 deletions deepmd/jax/utils/exclude_mask.py
@@ -3,13 +3,22 @@
Any,
)

from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
from deepmd.jax.common import (
flax_module,
to_jax_array,
)


@flax_module
class AtomExcludeMask(AtomExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"type_mask"}:
value = to_jax_array(value)
return super().__setattr__(name, value)


@flax_module
class PairExcludeMask(PairExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
1 change: 1 addition & 0 deletions source/tests/array_api_strict/fitting/__init__.py
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
38 changes: 38 additions & 0 deletions source/tests/array_api_strict/fitting/fitting.py
@@ -0,0 +1,38 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP

from ..common import (
to_array_api_strict_array,
)
from ..utils.exclude_mask import (
AtomExcludeMask,
)
from ..utils.network import (
NetworkCollection,
)


def setattr_for_general_fitting(name: str, value: Any) -> Any:
if name in {
"bias_atom_e",
"fparam_avg",
"fparam_inv_std",
"aparam_avg",
"aparam_inv_std",
}:
value = to_array_api_strict_array(value)
elif name == "emask":
value = AtomExcludeMask(value.ntypes, value.exclude_types)
elif name == "nets":
value = NetworkCollection.deserialize(value.serialize())
return value


class EnergyFittingNet(EnergyFittingNetDP):
def __setattr__(self, name: str, value: Any) -> None:
value = setattr_for_general_fitting(name, value)
return super().__setattr__(name, value)
8 changes: 8 additions & 0 deletions source/tests/array_api_strict/utils/exclude_mask.py
@@ -3,13 +3,21 @@
Any,
)

from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP

from ..common import (
to_array_api_strict_array,
)


class AtomExcludeMask(AtomExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"type_mask"}:
value = to_array_api_strict_array(value)
return super().__setattr__(name, value)


class PairExcludeMask(PairExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"type_mask"}:
67 changes: 67 additions & 0 deletions source/tests/consistent/fitting/test_ener.py
@@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
@@ -36,6 +38,22 @@
fitting_ener,
)

if INSTALLED_JAX:
from deepmd.jax.env import (
jnp,
)
from deepmd.jax.fitting.fitting import EnergyFittingNet as EnerFittingJAX
else:
EnerFittingJAX = object
if INSTALLED_ARRAY_API_STRICT:
import array_api_strict

from ...array_api_strict.fitting.fitting import (
EnergyFittingNet as EnerFittingStrict,
)
else:
EnerFittingStrict = None


@parameterized(
(True, False), # resnet_dt
@@ -74,9 +92,25 @@ def skip_pt(self) -> bool:
) = self.param
return CommonTest.skip_pt

skip_jax = not INSTALLED_JAX

@property
def skip_array_api_strict(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
# TypeError: The array_api_strict namespace does not support the dtype 'bfloat16'
return not INSTALLED_ARRAY_API_STRICT or precision == "bfloat16"

tf_class = EnerFittingTF
dp_class = EnerFittingDP
pt_class = EnerFittingPT
jax_class = EnerFittingJAX
array_api_strict_class = EnerFittingStrict
args = fitting_ener()

def setUp(self):
@@ -157,6 +191,39 @@ def eval_dp(self, dp_obj: Any) -> Any:
fparam=self.fparam if numb_fparam else None,
)["energy"]

def eval_jax(self, jax_obj: Any) -> Any:
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
return np.asarray(
jax_obj(
jnp.asarray(self.inputs),
jnp.asarray(self.atype.reshape(1, -1)),
fparam=jnp.asarray(self.fparam) if numb_fparam else None,
)["energy"]
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
array_api_strict.set_array_api_strict_flags(api_version="2023.12")
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
atom_ener,
) = self.param
return np.asarray(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None,
)["energy"]
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
if backend == self.RefBackend.TF:
# shape is not same
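The test additions follow the repository's consistency-test pattern: every backend evaluates the same fitting net on the same inputs, and the results are compared against a reference backend. Each `eval_*` method is the same three steps, sketched below with illustrative names (only the per-backend `asarray` conversion differs between `eval_jax` and `eval_array_api_strict`):

```python
import numpy as np


def eval_backend(net, inputs, atype, fparam, asarray):
    # 1. convert numpy test data into the backend's array type,
    # 2. run the fitting network,
    # 3. convert the predicted energy back to numpy for comparison.
    out = net(
        asarray(inputs),
        asarray(atype.reshape(1, -1)),
        fparam=asarray(fparam) if fparam is not None else None,
    )["energy"]
    return np.asarray(out)
```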
