feat(jax): force & virial #4251

Merged 8 commits on Oct 29, 2024
Changes from 1 commit
30 changes: 25 additions & 5 deletions deepmd/dpmodel/model/make_model.py
@@ -222,22 +222,42 @@ def call_lower(
             extended_coord, fparam=fparam, aparam=aparam
         )
         del extended_coord, fparam, aparam
-        atomic_ret = self.atomic_model.forward_common_atomic(
+        model_predict = self.forward_common_atomic(
             cc_ext,
             extended_atype,
             nlist,
             mapping=mapping,
             fparam=fp,
             aparam=ap,
+            do_atomic_virial=do_atomic_virial,
         )
+        model_predict = self.output_type_cast(model_predict, input_prec)
+        return model_predict
+
+    def forward_common_atomic(
+        self,
+        extended_coord: np.ndarray,
+        extended_atype: np.ndarray,
+        nlist: np.ndarray,
+        mapping: Optional[np.ndarray] = None,
+        fparam: Optional[np.ndarray] = None,
+        aparam: Optional[np.ndarray] = None,
+        do_atomic_virial: bool = False,
+    ):
+        atomic_ret = self.atomic_model.forward_common_atomic(
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping=mapping,
+            fparam=fparam,
+            aparam=aparam,
+        )
-        model_predict = fit_output_to_model_output(
+        return fit_output_to_model_output(
             atomic_ret,
             self.atomic_output_def(),
-            cc_ext,
+            extended_coord,
             do_atomic_virial=do_atomic_virial,
         )
-        model_predict = self.output_type_cast(model_predict, input_prec)
-        return model_predict
 
     forward_lower = call_lower
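Note: this refactor splits the atomic forward pass out of call_lower into an overridable forward_common_atomic hook, so a backend can swap in its own implementation (the JAX override appears below in deepmd/jax/model/ener_model.py) while reusing the rest of the pipeline. A minimal sketch of the pattern, using hypothetical stand-in names rather than the deepmd API:

import numpy as np

class Model:
    # Template method: call_lower delegates the atomic step to a hook.
    def call_lower(self, coords: np.ndarray) -> dict:
        return self.forward_common_atomic(coords)

    def forward_common_atomic(self, coords: np.ndarray) -> dict:
        # toy per-structure energy standing in for the atomic model
        return {"energy": float(np.sum(coords**2))}

class JAXLikeModel(Model):
    # A backend overrides only the hook, e.g. to attach autodiff forces.
    def forward_common_atomic(self, coords: np.ndarray) -> dict:
        ret = super().forward_common_atomic(coords)
        ret["energy_derv_r"] = -2.0 * coords  # analytic force of the toy energy
        return ret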
42 changes: 40 additions & 2 deletions deepmd/dpmodel/model/transform_output.py
@@ -9,6 +9,7 @@
 from deepmd.dpmodel.output_def import (
     FittingOutputDef,
     ModelOutputDef,
+    OutputVariableDef,
     get_deriv_name,
     get_reduce_name,
 )
@@ -47,6 +48,15 @@ def fit_output_to_model_output(
     return model_ret
 
 
+def get_leading_dims(
+    vv: np.ndarray,
+    vdef: OutputVariableDef,
+):
+    """Get the dimensions of nf x nloc."""
+    vshape = vv.shape
+    return list(vshape[: (len(vshape) - len(vdef.shape))])
+
+
 def communicate_extended_output(
     model_ret: dict[str, np.ndarray],
     model_output_def: ModelOutputDef,
@@ -57,6 +67,7 @@ def communicate_extended_output(
     local and ghost (extended) atoms to local atoms.
 
     """
+    xp = array_api_compat.get_namespace(mapping)
     new_ret = {}
     for kk in model_output_def.keys_outp():
         vv = model_ret[kk]
@@ -67,8 +78,35 @@ def communicate_extended_output(
             new_ret[kk_redu] = model_ret[kk_redu]
         if vdef.r_differentiable:
             kk_derv_r, kk_derv_c = get_deriv_name(kk)
-            # name holders
-            new_ret[kk_derv_r] = None
+            if model_ret[kk_derv_r] is not None:
+                mldims = list(mapping.shape)
+                vldims = get_leading_dims(vv, vdef)
+                derv_r_ext_dims = list(vdef.shape) + [3]  # noqa:RUF005
+                mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
+                mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
+                force = xp.zeros(
+                    vldims + derv_r_ext_dims, dtype=vv.dtype, device=vv.device
+                )
+                # jax only
+                if array_api_compat.is_jax_array(force):
+                    from deepmd.jax.env import (
+                        jnp,
+                    )
+
+                    f_idx = xp.arange(force.size, dtype=xp.int64).reshape(
+                        force.shape
+                    )
+                    new_idx = jnp.take_along_axis(f_idx, mapping, axis=1).ravel()
+                    f_shape = force.shape
+                    force = force.ravel()
+                    force = force.at[new_idx].add(model_ret[kk_derv_r].ravel())
+                    force = force.reshape(f_shape)
+                else:
+                    raise NotImplementedError("Only JAX arrays are supported.")
+                new_ret[kk_derv_r] = force
+            else:
+                # name holders
+                new_ret[kk_derv_r] = None
         if vdef.c_differentiable:
             assert vdef.r_differentiable
             kk_derv_r, kk_derv_c = get_deriv_name(kk)
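Note: the new branch above folds forces on ghost (extended) atoms back onto their owning local atoms with a flat-index scatter-add, since JAX arrays are immutable and in-place accumulation like np.add.at is unavailable. A self-contained sketch of the same take_along_axis trick on toy shapes (all values below are made up for illustration):

import jax.numpy as jnp

# nf=1 frame, nloc=2 local atoms, nall=3 extended atoms;
# extended atom 2 is a periodic image of local atom 0.
mapping = jnp.array([[0, 1, 0]])  # nf x nall, owning local atom of each extended atom
force_ext = jnp.ones((1, 3, 3))   # nf x nall x 3, forces on extended atoms
force = jnp.zeros((1, 2, 3))      # nf x nloc x 3, accumulator

# Gather the flat index of each extended atom's local slot ...
f_idx = jnp.arange(force.size).reshape(force.shape)
idx = jnp.take_along_axis(f_idx, jnp.tile(mapping[..., None], (1, 1, 3)), axis=1)
# ... then scatter-add in 1D so image contributions accumulate on the owner.
force = force.ravel().at[idx.ravel()].add(force_ext.ravel()).reshape(force.shape)
# force[0, 0] == [2., 2., 2.] (own force + image); force[0, 1] == [1., 1., 1.]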
2 changes: 2 additions & 0 deletions deepmd/dpmodel/utils/env_mat.py
@@ -61,6 +61,8 @@ def _make_env_mat(
     # nf x nloc x nnei x 3
     diff = coord_r - coord_l
     # nf x nloc x nnei
+    # the grad of JAX vector_norm is NaN at x=0
+    diff = xp.where(xp.abs(diff) < 1e-30, xp.full_like(diff, 1e-30), diff)
     length = xp.linalg.vector_norm(diff, axis=-1, keepdims=True)
     # for index 0 nloc atom
     length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype)
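Note: the guard added here works around a standard autodiff pitfall rather than anything deepmd-specific: the derivative of a Euclidean norm at exactly zero is 0/0, which JAX propagates as NaN. A quick demonstration of the failure mode and the same where-clamp fix (jnp.linalg.norm standing in for vector_norm):

import jax
import jax.numpy as jnp

def norm(x):
    return jnp.linalg.norm(x)  # sqrt(sum(x**2)); its grad divides by the norm

print(jax.grad(norm)(jnp.zeros(3)))  # [nan nan nan]

def safe_norm(x, eps=1e-30):
    # clamp tiny components away from zero before taking the norm
    x = jnp.where(jnp.abs(x) < eps, jnp.full_like(x, eps), x)
    return jnp.linalg.norm(x)

print(jax.grad(safe_norm)(jnp.zeros(3)))  # finite (all zeros here)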
1 change: 1 addition & 0 deletions deepmd/jax/env.py
@@ -10,6 +10,7 @@
 )
 
 jax.config.update("jax_enable_x64", True)
+# jax.config.update("jax_debug_nans", True)
 
 __all__ = [
     "jax",
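Note: the commented-out flag is the matching debugging aid: with jax_debug_nans enabled, JAX raises at the first operation that produces a NaN instead of letting it propagate silently, e.g.:

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)
# The NaN gradient from the previous sketch now fails loudly at its source:
# jax.grad(jnp.linalg.norm)(jnp.zeros(3))  # raises FloatingPointError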
82 changes: 82 additions & 0 deletions deepmd/jax/model/base_model.py
@@ -1,6 +1,88 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Optional,
+)
+
 from deepmd.dpmodel.model.base_model import (
     make_base_model,
 )
+from deepmd.dpmodel.output_def import (
+    get_deriv_name,
+    get_reduce_name,
+)
+from deepmd.jax.env import (
+    jax,
+    jnp,
+)
 
 BaseModel = make_base_model()
+
+
+def forward_common_atomic(
+    self,
+    extended_coord: jnp.ndarray,
+    extended_atype: jnp.ndarray,
+    nlist: jnp.ndarray,
+    mapping: Optional[jnp.ndarray] = None,
+    fparam: Optional[jnp.ndarray] = None,
+    aparam: Optional[jnp.ndarray] = None,
+    do_atomic_virial: bool = False,
+):
+    atomic_ret = self.atomic_model.forward_common_atomic(
+        extended_coord,
+        extended_atype,
+        nlist,
+        mapping=mapping,
+        fparam=fparam,
+        aparam=aparam,
+    )
+    atomic_output_def = self.atomic_output_def()
+    model_predict = {}
+    for kk, vv in atomic_ret.items():
+        model_predict[kk] = vv
+        vdef = atomic_output_def[kk]
+        shap = vdef.shape
+        atom_axis = -(len(shap) + 1)
+        if vdef.reducible:
+            kk_redu = get_reduce_name(kk)
+            model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
+            kk_derv_r, kk_derv_c = get_deriv_name(kk)
+            if vdef.r_differentiable:
+                size = 1
+                for ii in vdef.shape:
+                    size *= ii
+
+                split_ff = []
+                for ss in range(size):
+
+                    def eval_output(
+                        cc_ext, extended_atype, nlist, mapping, fparam, aparam
+                    ):
+                        atomic_ret = self.atomic_model.forward_common_atomic(
+                            cc_ext[None, ...],
+                            extended_atype[None, ...],
+                            nlist[None, ...],
+                            mapping=mapping[None, ...] if mapping is not None else None,
+                            fparam=fparam[None, ...] if fparam is not None else None,
+                            aparam=aparam[None, ...] if aparam is not None else None,
+                        )
+                        return jnp.sum(atomic_ret[kk][0], axis=atom_axis)[ss]
+
+                    ffi = -jax.vmap(jax.grad(eval_output, argnums=0))(
+                        extended_coord,
+                        extended_atype,
+                        nlist,
+                        mapping,
+                        fparam,
+                        aparam,
+                    )
+                    ffi = ffi[..., None, :]
+                    split_ff.append(ffi)
+                out_lead_shape = list(extended_coord.shape[:-1]) + vdef.shape
+                ff = jnp.concatenate(split_ff, axis=-2).reshape(*out_lead_shape, 3)
+
+                model_predict[kk_derv_r] = ff
+            if vdef.c_differentiable:
+                assert vdef.r_differentiable
+                model_predict[kk_derv_c] = None
+    return model_predict
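Note: the loop above realizes forces as minus the gradient of the reduced output with respect to the extended coordinates, taking jax.grad per output component and jax.vmap over the frame dimension. The core identity on a toy energy (a stand-in for the atomic model, not deepmd code):

import jax
import jax.numpy as jnp

def atomic_energy(coords):
    # toy per-atom energy, (natoms, 3) -> (natoms,)
    return 0.5 * jnp.sum(coords**2, axis=-1)

def total_energy(coords):
    return jnp.sum(atomic_energy(coords))  # reduce over atoms, like energy_redu

coords = jnp.array([[[0.0, 0.0, 1.0], [0.5, -1.0, 2.0]]])  # nf x natoms x 3
# F = -dE/dx per frame; vmap maps the single-frame gradient over the batch.
forces = -jax.vmap(jax.grad(total_energy))(coords)
print(forces)  # equals -coords for this harmonic toy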
26 changes: 26 additions & 0 deletions deepmd/jax/model/ener_model.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
     Any,
+    Optional,
 )
 
 from deepmd.dpmodel.model import EnergyModel as EnergyModelDP
@@ -10,8 +11,12 @@
 from deepmd.jax.common import (
     flax_module,
 )
+from deepmd.jax.env import (
+    jnp,
+)
 from deepmd.jax.model.base_model import (
     BaseModel,
+    forward_common_atomic,
 )
@@ -22,3 +27,24 @@ def __setattr__(self, name: str, value: Any) -> None:
         if name == "atomic_model":
             value = DPAtomicModel.deserialize(value.serialize())
         return super().__setattr__(name, value)
+
+    def forward_common_atomic(
+        self,
+        extended_coord: jnp.ndarray,
+        extended_atype: jnp.ndarray,
+        nlist: jnp.ndarray,
+        mapping: Optional[jnp.ndarray] = None,
+        fparam: Optional[jnp.ndarray] = None,
+        aparam: Optional[jnp.ndarray] = None,
+        do_atomic_virial: bool = False,
+    ):
+        return forward_common_atomic(
+            self,
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping=mapping,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
4 changes: 4 additions & 0 deletions source/tests/consistent/common.py
@@ -69,6 +69,8 @@
     "INSTALLED_ARRAY_API_STRICT",
 ]
 
+SKIP_FLAG = object()
+
 
 class CommonTest(ABC):
     data: ClassVar[dict]
@@ -362,6 +364,8 @@ def test_dp_consistent_with_ref(self):
         data2 = dp_obj.serialize()
         np.testing.assert_equal(data1, data2)
         for rr1, rr2 in zip(ret1, ret2):
+            if rr1 is SKIP_FLAG or rr2 is SKIP_FLAG:
+                continue
             np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
             assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"
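Note: SKIP_FLAG is a sentinel compared by identity, so a backend can opt out of one slot of the returned tuple without ambiguity (None or NaN could be legitimate outputs). The pattern in isolation:

SKIP_FLAG = object()

ref = (1.0, 2.0, SKIP_FLAG)  # reference backend cannot produce a force here
test = (1.0, 2.0, 3.0)
for rr1, rr2 in zip(ref, test):
    if rr1 is SKIP_FLAG or rr2 is SKIP_FLAG:
        continue  # skip slots one side cannot produce
    assert rr1 == rr2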
2 changes: 1 addition & 1 deletion source/tests/consistent/model/common.py
@@ -51,7 +51,7 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix):
             {},
             suffix=suffix,
         )
-        return [ret["energy"], ret["atom_ener"]], {
+        return [ret["energy"], ret["atom_ener"], ret["force"]], {
             t_coord: coords,
             t_type: atype,
             t_natoms: natoms,
32 changes: 28 additions & 4 deletions source/tests/consistent/model/test_ener.py
@@ -16,6 +16,7 @@
     INSTALLED_JAX,
     INSTALLED_PT,
     INSTALLED_TF,
+    SKIP_FLAG,
     CommonTest,
     parameterized,
 )
@@ -94,6 +95,21 @@ def data(self) -> dict:
     jax_class = EnergyModelJAX
     args = model_args()
 
+    def get_reference_backend(self):
+        """Get the reference backend.
+
+        We need a reference backend that can reproduce forces.
+        """
+        if not self.skip_pt:
+            return self.RefBackend.PT
+        if not self.skip_tf:
+            return self.RefBackend.TF
+        if not self.skip_jax:
+            return self.RefBackend.JAX
+        if not self.skip_dp:
+            return self.RefBackend.DP
+        raise ValueError("No available reference")
+
     @property
     def skip_tf(self):
         return (
@@ -195,11 +211,19 @@ def eval_jax(self, jax_obj: Any) -> Any:
     def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
         # shape not matched. ravel...
         if backend is self.RefBackend.DP:
-            return (ret["energy_redu"].ravel(), ret["energy"].ravel())
+            return (ret["energy_redu"].ravel(), ret["energy"].ravel(), SKIP_FLAG)
         elif backend is self.RefBackend.PT:
-            return (ret["energy"].ravel(), ret["atom_energy"].ravel())
+            return (
+                ret["energy"].ravel(),
+                ret["atom_energy"].ravel(),
+                ret["force"].ravel(),
+            )
         elif backend is self.RefBackend.TF:
-            return (ret[0].ravel(), ret[1].ravel())
+            return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel())
         elif backend is self.RefBackend.JAX:
-            return (ret["energy_redu"].ravel(), ret["energy"].ravel())
+            return (
+                ret["energy_redu"].ravel(),
+                ret["energy"].ravel(),
+                ret["energy_derv_r"].ravel(),
+            )
         raise ValueError(f"Unknown backend: {backend}")