feat(jax): force & virial #4251

Merged · merged 8 commits on Oct 29, 2024 · Changes from 4 commits
30 changes: 25 additions & 5 deletions deepmd/dpmodel/model/make_model.py
@@ -222,22 +222,42 @@ def call_lower(
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam
atomic_ret = self.atomic_model.forward_common_atomic(
model_predict = self.forward_common_atomic(
cc_ext,
extended_atype,
nlist,
mapping=mapping,
fparam=fp,
aparam=ap,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict

def forward_common_atomic(
self,
extended_coord: np.ndarray,
extended_atype: np.ndarray,
nlist: np.ndarray,
mapping: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
do_atomic_virial: bool = False,
):
atomic_ret = self.atomic_model.forward_common_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)
model_predict = fit_output_to_model_output(
return fit_output_to_model_output(
atomic_ret,
self.atomic_output_def(),
cc_ext,
extended_coord,
do_atomic_virial=do_atomic_virial,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict

forward_lower = call_lower

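The refactor above splits the atomic forward out of `call_lower` so that a backend can override it. A minimal sketch of the hook pattern (hypothetical toy classes, not the DeePMD API): the base `call_lower` stays backend-agnostic, while a JAX-side subclass overrides `forward_common_atomic` to attach extra outputs such as the force.

```python
# Hypothetical stand-in classes illustrating the override hook introduced above.
class ModelSketch:
    def call_lower(self, coord):
        # backend-agnostic driver: delegate the atomic forward, then post-process
        return self.forward_common_atomic(coord)

    def forward_common_atomic(self, coord):
        # default path: atomic model output plus the per-frame reduction
        return {"energy_redu": sum(coord)}


class JaxModelSketch(ModelSketch):
    def forward_common_atomic(self, coord):
        # same atomic call, plus derivative outputs computed by autodiff in the real code
        ret = super().forward_common_atomic(coord)
        ret["force"] = [-c for c in coord]
        return ret


print(JaxModelSketch().call_lower([0.1, 0.2]))
# {'energy_redu': 0.30000000000000004, 'force': [-0.1, -0.2]}
```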
85 changes: 79 additions & 6 deletions deepmd/dpmodel/model/transform_output.py
@@ -9,6 +9,7 @@
from deepmd.dpmodel.output_def import (
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
get_deriv_name,
get_reduce_name,
)
@@ -47,6 +48,28 @@
return model_ret


def get_leading_dims(
vv: np.ndarray,
vdef: OutputVariableDef,
):
"""Get the dimensions of nf x nloc.

Parameters
----------
vv : np.ndarray
The input array from which to compute the leading dimensions.
vdef : OutputVariableDef
The output variable definition containing the shape to exclude from `vv`.

Returns
-------
list
A list of leading dimensions of `vv`, excluding the last `len(vdef.shape)` dimensions.
"""
vshape = vv.shape
return list(vshape[: (len(vshape) - len(vdef.shape))])


def communicate_extended_output(
model_ret: dict[str, np.ndarray],
model_output_def: ModelOutputDef,
@@ -57,6 +80,7 @@
local and ghost (extended) atoms to local atoms.

"""
xp = array_api_compat.get_namespace(mapping)
new_ret = {}
for kk in model_output_def.keys_outp():
vv = model_ret[kk]
@@ -65,15 +89,64 @@
if vdef.reducible:
kk_redu = get_reduce_name(kk)
new_ret[kk_redu] = model_ret[kk_redu]
kk_derv_r, kk_derv_c = get_deriv_name(kk)
mldims = list(mapping.shape)
vldims = get_leading_dims(vv, vdef)
if vdef.r_differentiable:
kk_derv_r, kk_derv_c = get_deriv_name(kk)
# name holders
new_ret[kk_derv_r] = None
if model_ret[kk_derv_r] is not None:
derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005
mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
force = xp.zeros(
vldims + derv_r_ext_dims, dtype=vv.dtype, device=vv.device
)
# jax only
if array_api_compat.is_jax_array(force):
from deepmd.jax.common import (
scatter_sum,
)

force = scatter_sum(
force,
1,
mapping,
model_ret[kk_derv_r],
)
else:
raise NotImplementedError("Only JAX arrays are supported.")

new_ret[kk_derv_r] = force
else:
# name holders
new_ret[kk_derv_r] = None
if vdef.c_differentiable:
assert vdef.r_differentiable
kk_derv_r, kk_derv_c = get_deriv_name(kk)
new_ret[kk_derv_c] = None
new_ret[kk_derv_c + "_redu"] = None
if model_ret[kk_derv_c] is not None:
derv_c_ext_dims = list(vdef.shape) + [9] # noqa:RUF005
mapping = xp.tile(
mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
)
virial = xp.zeros(
vldims + derv_c_ext_dims, dtype=vv.dtype, device=vv.device
)
# jax only
if array_api_compat.is_jax_array(virial):
from deepmd.jax.common import (
scatter_sum,
)

virial = scatter_sum(
virial,
1,
mapping,
model_ret[kk_derv_c],
)
else:
raise NotImplementedError("Only JAX arrays are supported.")

new_ret[kk_derv_c] = virial
new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1)
else:
new_ret[kk_derv_c] = None
new_ret[kk_derv_c + "_redu"] = None
if not do_atomic_virial:
# pop atomic virial, because it is not correctly calculated.
new_ret.pop(kk_derv_c)
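For reference, the reduction that `communicate_extended_output` performs can be walked through with toy shapes. The sizes below are hypothetical, and the single `np.add.at` call stands in for the tiled-`mapping` plus `scatter_sum` path used on JAX arrays; `get_leading_dims` supplies the `[nf, nloc]` prefix of the accumulator.

```python
import numpy as np

# Hypothetical toy sizes: 1 frame, 2 local atoms, 3 extended atoms
# (extended atom 2 is a periodic image of local atom 0); energy-like output, vdef.shape == [1].
nf, nloc, nall = 1, 2, 3
vv = np.zeros((nf, nloc, 1))                  # per-local-atom output
vldims = list(vv.shape[:-1])                  # what get_leading_dims returns: [nf, nloc]
derv_r_ext_dims = [1, 3]                      # list(vdef.shape) + [3]

mapping = np.array([[0, 1, 0]])               # (nf, nall): extended atom -> owning local atom
extended_force = np.arange(9.0).reshape(nf, nall, 1, 3)

force = np.zeros(vldims + derv_r_ext_dims)    # (nf, nloc, 1, 3) accumulator
# fold every extended atom's contribution onto its local owner; on JAX arrays the
# code above does this with a tiled mapping and scatter_sum
np.add.at(force, (0, mapping[0]), extended_force[0])
assert np.allclose(force[0, 0], extended_force[0, 0] + extended_force[0, 2])
assert np.allclose(force[0, 1], extended_force[0, 1])
```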
4 changes: 3 additions & 1 deletion deepmd/dpmodel/utils/env_mat.py
@@ -61,7 +61,9 @@ def _make_env_mat(
# nf x nloc x nnei x 3
diff = coord_r - coord_l
# nf x nloc x nnei
length = xp.linalg.vector_norm(diff, axis=-1, keepdims=True)
# the grad of JAX vector_norm is NaN at x=0
diff_ = xp.where(xp.abs(diff) < 1e-30, xp.full_like(diff, 1e-30), diff)
length = xp.linalg.vector_norm(diff_, axis=-1, keepdims=True)
# for index 0 nloc atom
length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype)
t0 = 1 / (length + protection)
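This guards against a known JAX autodiff corner case: the backward pass of the Euclidean norm divides by the norm itself, so the gradient at exactly zero separation is NaN. A minimal reproduction with the same `where`-style clamp (assumed epsilon, and `jnp.linalg.norm` standing in for the array-API `vector_norm`):

```python
import jax
import jax.numpy as jnp

# d/dx sqrt(sum(x**2)) = x / norm(x): at x = 0 this is 0/0, so the gradient is NaN
print(jax.grad(lambda x: jnp.linalg.norm(x))(jnp.zeros(3)))   # [nan nan nan]

def safe_norm(x, eps=1e-30):
    # clamp tiny components before the norm; the affected self-pair entries are
    # masked out downstream, so the clamp does not change the result
    x = jnp.where(jnp.abs(x) < eps, jnp.full_like(x, eps), x)
    return jnp.linalg.norm(x)

print(jax.grad(safe_norm)(jnp.zeros(3)))                      # finite (all zeros here)
```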
10 changes: 10 additions & 0 deletions deepmd/jax/common.py
@@ -95,3 +95,13 @@ def __dlpack__(self, *args, **kwargs):

def __dlpack_device__(self, *args, **kwargs):
return self.value.__dlpack_device__(*args, **kwargs)


def scatter_sum(input, dim, index: jnp.ndarray, src: jnp.ndarray) -> jnp.ndarray:
"""Reduces all values from the src tensor to the indices specified in the index tensor."""
idx = jnp.arange(input.size, dtype=jnp.int64).reshape(input.shape)
new_idx = jnp.take_along_axis(idx, index, axis=dim).ravel()
shape = input.shape
input = input.ravel()
input = input.at[new_idx].add(src.ravel())
return input.reshape(shape)
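`scatter_sum` is a JAX counterpart of a PyTorch-style scatter-add: every element of `src` is added into `input` at the position obtained by replacing the `dim`-th index with the corresponding value from `index`. A small usage sketch with assumed toy shapes, assuming the `deepmd.jax.common` module from this PR is importable:

```python
import jax.numpy as jnp
from deepmd.jax.common import scatter_sum  # module touched by this PR

# toy reduction: 1 frame, 2 local atoms, 3 extended atoms owned by local atoms [0, 1, 0]
mapping = jnp.array([[0, 1, 0]])                         # (nf, nall)
src = jnp.arange(9.0).reshape(1, 3, 3)                   # (nf, nall, 3) per-extended-atom values
index = jnp.broadcast_to(mapping[..., None], src.shape)  # expanded to the same shape as src
out = scatter_sum(jnp.zeros((1, 2, 3)), 1, index, src)   # (nf, nloc, 3)
# out[0, 0] == src[0, 0] + src[0, 2]; out[0, 1] == src[0, 1]
```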
1 change: 1 addition & 0 deletions deepmd/jax/env.py
@@ -10,6 +10,7 @@
)

jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_debug_nans", True)

__all__ = [
"jax",
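The commented-out flag is a debugging aid worth noting: with `jax_debug_nans` enabled, JAX raises as soon as a NaN is produced, which is one way to trace problems like the `vector_norm` gradient above. A short sketch:

```python
import jax
import jax.numpy as jnp

# opt in only while debugging: the NaN checks re-run offending computations eagerly and slow things down
jax.config.update("jax_debug_nans", True)
jax.grad(lambda x: jnp.linalg.norm(x))(jnp.zeros(3))  # raises FloatingPointError at the NaN gradient
```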
104 changes: 104 additions & 0 deletions deepmd/jax/model/base_model.py
@@ -1,6 +1,110 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
)

from deepmd.dpmodel.model.base_model import (
make_base_model,
)
from deepmd.dpmodel.output_def import (
get_deriv_name,
get_reduce_name,
)
from deepmd.jax.env import (
jax,
jnp,
)

BaseModel = make_base_model()


def forward_common_atomic(
self,
extended_coord: jnp.ndarray,
extended_atype: jnp.ndarray,
nlist: jnp.ndarray,
mapping: Optional[jnp.ndarray] = None,
fparam: Optional[jnp.ndarray] = None,
aparam: Optional[jnp.ndarray] = None,
do_atomic_virial: bool = False,
):
atomic_ret = self.atomic_model.forward_common_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)
atomic_output_def = self.atomic_output_def()
model_predict = {}
for kk, vv in atomic_ret.items():
model_predict[kk] = vv
vdef = atomic_output_def[kk]
shap = vdef.shape
atom_axis = -(len(shap) + 1)
if vdef.reducible:
kk_redu = get_reduce_name(kk)
model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
kk_derv_r, kk_derv_c = get_deriv_name(kk)
if vdef.c_differentiable:
size = 1
for ii in vdef.shape:
size *= ii

split_ff = []
split_vv = []
for ss in range(size):

def eval_output(
cc_ext,
extended_atype,
nlist,
mapping,
fparam,
aparam,
*,
_kk=kk,
_ss=ss,
_atom_axis=atom_axis,
):
atomic_ret = self.atomic_model.forward_common_atomic(
cc_ext[None, ...],
extended_atype[None, ...],
nlist[None, ...],
mapping=mapping[None, ...] if mapping is not None else None,
fparam=fparam[None, ...] if fparam is not None else None,
aparam=aparam[None, ...] if aparam is not None else None,
)
return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis)[_ss]

ffi = -jax.vmap(jax.grad(eval_output, argnums=0))(
extended_coord,
extended_atype,
nlist,
mapping,
fparam,
aparam,
)
aviri = ffi[..., None] @ extended_coord[..., None, :]
ffi = ffi[..., None, :]
split_ff.append(ffi)
aviri = aviri[..., None, :]
split_vv.append(aviri)
out_lead_shape = list(extended_coord.shape[:-1]) + vdef.shape
extended_force = jnp.concat(split_ff, axis=-2).reshape(
*out_lead_shape, 3
)

model_predict[kk_derv_r] = extended_force
if vdef.c_differentiable:
assert vdef.r_differentiable
extended_virial = jnp.concat(split_vv, axis=-2).reshape(
*out_lead_shape, 9
)
# the correction sums to zero, which does not contribute to global virial
if do_atomic_virial:
raise NotImplementedError("Atomic virial is not implemented yet.")

# to [...,3,3] -> [...,9]
model_predict[kk_derv_c] = extended_virial
return model_predict
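The per-component `jax.vmap(jax.grad(...))` construction above reduces to a small self-contained pattern. In the sketch below, a hypothetical scalar per-frame energy stands in for `self.atomic_model.forward_common_atomic` (and for the loop over output components); the force is the negative gradient vmapped over frames, and the per-atom virial is the outer product of that force with the extended coordinate, mirroring `aviri = ffi[..., None] @ extended_coord[..., None, :]`.

```python
import jax
import jax.numpy as jnp

# hypothetical per-frame energy standing in for the atomic model; coord: (nall, 3)
def frame_energy(coord):
    diff = coord[:, None, :] - coord[None, :, :]
    r2 = jnp.sum(diff * diff, axis=-1)
    return jnp.sum(jnp.exp(-r2))

coord = jnp.linspace(0.0, 1.0, 2 * 4 * 3).reshape(2, 4, 3)   # (nframes, nall, 3)

# force on every extended atom: minus the gradient, vmapped over frames like eval_output above
extended_force = -jax.vmap(jax.grad(frame_energy))(coord)    # (2, 4, 3)

# per-atom virial contribution as the outer product f_i (x) r_i, flattened to 9 components
extended_virial = (extended_force[..., :, None] @ coord[..., None, :]).reshape(2, 4, 9)

# summing the per-atom contributions gives the global virial
# (in the real code this happens after mapping back to local atoms)
virial = jnp.sum(extended_virial, axis=1)                    # (2, 9)
```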
26 changes: 26 additions & 0 deletions deepmd/jax/model/ener_model.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Optional,
)

from deepmd.dpmodel.model import EnergyModel as EnergyModelDP
@@ -10,8 +11,12 @@
from deepmd.jax.common import (
flax_module,
)
from deepmd.jax.env import (
jnp,
)
from deepmd.jax.model.base_model import (
BaseModel,
forward_common_atomic,
)


@@ -22,3 +27,24 @@ def __setattr__(self, name: str, value: Any) -> None:
if name == "atomic_model":
value = DPAtomicModel.deserialize(value.serialize())
return super().__setattr__(name, value)

def forward_common_atomic(
self,
extended_coord: jnp.ndarray,
extended_atype: jnp.ndarray,
nlist: jnp.ndarray,
mapping: Optional[jnp.ndarray] = None,
fparam: Optional[jnp.ndarray] = None,
aparam: Optional[jnp.ndarray] = None,
do_atomic_virial: bool = False,
):
return forward_common_atomic(
self,
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
4 changes: 4 additions & 0 deletions source/tests/consistent/common.py
@@ -69,6 +69,8 @@
"INSTALLED_ARRAY_API_STRICT",
]

SKIP_FLAG = object()


class CommonTest(ABC):
data: ClassVar[dict]
@@ -362,6 +364,8 @@ def test_dp_consistent_with_ref(self):
data2 = dp_obj.serialize()
np.testing.assert_equal(data1, data2)
for rr1, rr2 in zip(ret1, ret2):
if rr1 is SKIP_FLAG or rr2 is SKIP_FLAG:
continue
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

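`SKIP_FLAG` is a module-level sentinel compared by identity: a backend that cannot yet produce one of the reference return values yields the flag in that slot, and the consistency loop skips it instead of failing. A minimal sketch with hypothetical values:

```python
import numpy as np

SKIP_FLAG = object()  # identity-compared sentinel, mirroring the one above

ret_ref = (1.0, SKIP_FLAG, 3.0)   # e.g. a backend that cannot return the force yet
ret_new = (1.0, 2.0, 3.0)
for rr1, rr2 in zip(ret_ref, ret_new):
    if rr1 is SKIP_FLAG or rr2 is SKIP_FLAG:
        continue                  # skip slots a backend cannot provide
    np.testing.assert_allclose(rr1, rr2, rtol=1e-10, atol=1e-10)
```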
2 changes: 1 addition & 1 deletion source/tests/consistent/model/common.py
@@ -51,7 +51,7 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix):
{},
suffix=suffix,
)
return [ret["energy"], ret["atom_ener"]], {
return [ret["energy"], ret["atom_ener"], ret["force"], ret["virial"]], {
t_coord: coords,
t_type: atype,
t_natoms: natoms,