Skip to content

Commit

Permalink
support virial
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 25, 2024
1 parent d0b576f commit 004b89a
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 13 deletions.
53 changes: 46 additions & 7 deletions deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,20 @@ def get_leading_dims(
vv: np.ndarray,
vdef: OutputVariableDef,
):
"""Get the dimensions of nf x nloc."""
"""Get the dimensions of nf x nloc.
Parameters
----------
vv : np.ndarray
The input array from which to compute the leading dimensions.
vdef : OutputVariableDef
The output variable definition containing the shape to exclude from `vv`.
Returns
-------
list
A list of leading dimensions of `vv`, excluding the last `len(vdef.shape)` dimensions.
"""
vshape = vv.shape
return list(vshape[: (len(vshape) - len(vdef.shape))])

Expand All @@ -76,11 +89,11 @@ def communicate_extended_output(
if vdef.reducible:
kk_redu = get_reduce_name(kk)
new_ret[kk_redu] = model_ret[kk_redu]
kk_derv_r, kk_derv_c = get_deriv_name(kk)
mldims = list(mapping.shape)
vldims = get_leading_dims(vv, vdef)
if vdef.r_differentiable:
kk_derv_r, kk_derv_c = get_deriv_name(kk)
if model_ret[kk_derv_r] is not None:
mldims = list(mapping.shape)
vldims = get_leading_dims(vv, vdef)
derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005
mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
Expand Down Expand Up @@ -109,9 +122,35 @@ def communicate_extended_output(
new_ret[kk_derv_r] = None
if vdef.c_differentiable:
assert vdef.r_differentiable
kk_derv_r, kk_derv_c = get_deriv_name(kk)
new_ret[kk_derv_c] = None
new_ret[kk_derv_c + "_redu"] = None
if model_ret[kk_derv_c] is not None:
derv_c_ext_dims = list(vdef.shape) + [9] # noqa:RUF005
mapping = xp.tile(
mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
)
virial = xp.zeros(
vldims + derv_c_ext_dims, dtype=vv.dtype, device=vv.device
)
# jax only
if array_api_compat.is_jax_array(virial):
from deepmd.jax.env import (
jnp,
)

v_idx = xp.arange(virial.size, dtype=xp.int64).reshape(
virial.shape
)
new_idx = jnp.take_along_axis(v_idx, mapping, axis=1).ravel()
v_shape = virial.shape
virial = virial.ravel()
virial = virial.at[new_idx].add(model_ret[kk_derv_c].ravel())
virial = virial.reshape(v_shape)
else:
raise NotImplementedError("Only JAX arrays are supported.")
new_ret[kk_derv_c] = virial
new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1)
else:
new_ret[kk_derv_c] = None
new_ret[kk_derv_c + "_redu"] = None
if not do_atomic_virial:
# pop atomic virial, because it is not correctly calculated.
new_ret.pop(kk_derv_c)
Expand Down
19 changes: 16 additions & 3 deletions deepmd/jax/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def forward_common_atomic(
size *= ii

split_ff = []
split_vv = []
for ss in range(size):

def eval_output(
Expand All @@ -76,13 +77,25 @@ def eval_output(
fparam,
aparam,
)
aviri = ffi[..., None] @ extended_coord[..., None, :]
ffi = ffi[..., None, :]
split_ff.append(ffi)
aviri = aviri[..., None, :]
split_vv.append(aviri)
out_lead_shape = list(extended_coord.shape[:-1]) + vdef.shape
ff = jnp.concatenate(split_ff, axis=-2).reshape(*out_lead_shape, 3)
extended_force = jnp.concat(split_ff, axis=-2).reshape(
*out_lead_shape, 3
)

model_predict[kk_derv_r] = ff
model_predict[kk_derv_r] = extended_force
if vdef.c_differentiable:
assert vdef.r_differentiable
model_predict[kk_derv_c] = None
extended_virial = jnp.concat(split_vv, axis=-2).reshape(
*out_lead_shape, 9
)
# the correction sums to zero, which does not contribute to global virial
if do_atomic_virial:
raise NotImplementedError("Atomic virial is not implemented yet.")
# to [...,3,3] -> [...,9]
model_predict[kk_derv_c] = extended_virial
return model_predict
2 changes: 1 addition & 1 deletion source/tests/consistent/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix):
{},
suffix=suffix,
)
return [ret["energy"], ret["atom_ener"], ret["force"]], {
return [ret["energy"], ret["atom_ener"], ret["force"], ret["virial"]], {
t_coord: coords,
t_type: atype,
t_natoms: natoms,
Expand Down
11 changes: 9 additions & 2 deletions source/tests/consistent/model/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,19 +211,26 @@ def eval_jax(self, jax_obj: Any) -> Any:
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
# shape not matched. ravel...
if backend is self.RefBackend.DP:
return (ret["energy_redu"].ravel(), ret["energy"].ravel(), SKIP_FLAG)
return (
ret["energy_redu"].ravel(),
ret["energy"].ravel(),
SKIP_FLAG,
SKIP_FLAG,
)
elif backend is self.RefBackend.PT:
return (
ret["energy"].ravel(),
ret["atom_energy"].ravel(),
ret["force"].ravel(),
ret["virial"].ravel(),
)
elif backend is self.RefBackend.TF:
return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel())
return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel())
elif backend is self.RefBackend.JAX:
return (
ret["energy_redu"].ravel(),
ret["energy"].ravel(),
ret["energy_derv_r"].ravel(),
ret["energy_derv_c_redu"].ravel(),
)
raise ValueError(f"Unknown backend: {backend}")

0 comments on commit 004b89a

Please sign in to comment.