From 0517b593066b7810563c207e59b8e0e088289f9c Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Thu, 24 Oct 2024 22:02:01 -0400
Subject: [PATCH 1/8] feat(jax): force

Signed-off-by: Jinzhe Zeng
---
 deepmd/dpmodel/model/make_model.py         | 30 ++++++--
 deepmd/dpmodel/model/transform_output.py   | 42 ++++++++++-
 deepmd/dpmodel/utils/env_mat.py            |  2 +
 deepmd/jax/env.py                          |  1 +
 deepmd/jax/model/base_model.py             | 82 ++++++++++++++++++++++
 deepmd/jax/model/ener_model.py             | 26 +++++++
 source/tests/consistent/common.py          |  4 ++
 source/tests/consistent/model/common.py    |  2 +-
 source/tests/consistent/model/test_ener.py | 32 +++++++--
 9 files changed, 209 insertions(+), 12 deletions(-)

diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py
index dc90f10da7..4007035f14 100644
--- a/deepmd/dpmodel/model/make_model.py
+++ b/deepmd/dpmodel/model/make_model.py
@@ -222,22 +222,42 @@ def call_lower(
             extended_coord, fparam=fparam, aparam=aparam
         )
         del extended_coord, fparam, aparam
-        atomic_ret = self.atomic_model.forward_common_atomic(
+        model_predict = self.forward_common_atomic(
             cc_ext,
             extended_atype,
             nlist,
             mapping=mapping,
             fparam=fp,
             aparam=ap,
+            do_atomic_virial=do_atomic_virial,
+        )
+        model_predict = self.output_type_cast(model_predict, input_prec)
+        return model_predict
+
+    def forward_common_atomic(
+        self,
+        extended_coord: np.ndarray,
+        extended_atype: np.ndarray,
+        nlist: np.ndarray,
+        mapping: Optional[np.ndarray] = None,
+        fparam: Optional[np.ndarray] = None,
+        aparam: Optional[np.ndarray] = None,
+        do_atomic_virial: bool = False,
+    ):
+        atomic_ret = self.atomic_model.forward_common_atomic(
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping=mapping,
+            fparam=fparam,
+            aparam=aparam,
         )
-        model_predict = fit_output_to_model_output(
+        return fit_output_to_model_output(
             atomic_ret,
             self.atomic_output_def(),
-            cc_ext,
+            extended_coord,
             do_atomic_virial=do_atomic_virial,
         )
-        model_predict = self.output_type_cast(model_predict, input_prec)
-        return model_predict
 
     forward_lower = call_lower
diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py
index 928c33f3bd..63b9818397 100644
--- a/deepmd/dpmodel/model/transform_output.py
+++ b/deepmd/dpmodel/model/transform_output.py
@@ -9,6 +9,7 @@
 from deepmd.dpmodel.output_def import (
     FittingOutputDef,
     ModelOutputDef,
+    OutputVariableDef,
     get_deriv_name,
     get_reduce_name,
 )
@@ -47,6 +48,15 @@ def fit_output_to_model_output(
     return model_ret
 
 
+def get_leading_dims(
+    vv: np.ndarray,
+    vdef: OutputVariableDef,
+):
+    """Get the dimensions of nf x nloc."""
+    vshape = vv.shape
+    return list(vshape[: (len(vshape) - len(vdef.shape))])
+
+
 def communicate_extended_output(
     model_ret: dict[str, np.ndarray],
     model_output_def: ModelOutputDef,
@@ -57,6 +67,7 @@ def communicate_extended_output(
     local and ghost (extended) atoms to local atoms.
""" + xp = array_api_compat.get_namespace(mapping) new_ret = {} for kk in model_output_def.keys_outp(): vv = model_ret[kk] @@ -67,8 +78,35 @@ def communicate_extended_output( new_ret[kk_redu] = model_ret[kk_redu] if vdef.r_differentiable: kk_derv_r, kk_derv_c = get_deriv_name(kk) - # name holders - new_ret[kk_derv_r] = None + if model_ret[kk_derv_r] is not None: + mldims = list(mapping.shape) + vldims = get_leading_dims(vv, vdef) + derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005 + mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims))) + mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims) + force = xp.zeros( + vldims + derv_r_ext_dims, dtype=vv.dtype, device=vv.device + ) + # jax only + if array_api_compat.is_jax_array(force): + from deepmd.jax.env import ( + jnp, + ) + + f_idx = xp.arange(force.size, dtype=xp.int64).reshape( + force.shape + ) + new_idx = jnp.take_along_axis(f_idx, mapping, axis=1).ravel() + f_shape = force.shape + force = force.ravel() + force = force.at[new_idx].add(model_ret[kk_derv_r].ravel()) + force = force.reshape(f_shape) + else: + raise NotImplementedError("Only JAX arrays are supported.") + new_ret[kk_derv_r] = force + else: + # name holders + new_ret[kk_derv_r] = None if vdef.c_differentiable: assert vdef.r_differentiable kk_derv_r, kk_derv_c = get_deriv_name(kk) diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index f4bc333a03..3cd729f553 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -61,6 +61,8 @@ def _make_env_mat( # nf x nloc x nnei x 3 diff = coord_r - coord_l # nf x nloc x nnei + # the grad of JAX vector_norm is NaN at x=0 + diff = xp.where(xp.abs(diff) < 1e-30, xp.full_like(diff, 1e-30), diff) length = xp.linalg.vector_norm(diff, axis=-1, keepdims=True) # for index 0 nloc atom length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype) diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py index 5a5a7f6bf0..ee11e17125 100644 --- a/deepmd/jax/env.py +++ b/deepmd/jax/env.py @@ -10,6 +10,7 @@ ) jax.config.update("jax_enable_x64", True) +# jax.config.update("jax_debug_nans", True) __all__ = [ "jax", diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py index fee4855da3..3d4450468b 100644 --- a/deepmd/jax/model/base_model.py +++ b/deepmd/jax/model/base_model.py @@ -1,6 +1,88 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + from deepmd.dpmodel.model.base_model import ( make_base_model, ) +from deepmd.dpmodel.output_def import ( + get_deriv_name, + get_reduce_name, +) +from deepmd.jax.env import ( + jax, + jnp, +) BaseModel = make_base_model() + + +def forward_common_atomic( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, +): + atomic_ret = self.atomic_model.forward_common_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + atomic_output_def = self.atomic_output_def() + model_predict = {} + for kk, vv in atomic_ret.items(): + model_predict[kk] = vv + vdef = atomic_output_def[kk] + shap = vdef.shape + atom_axis = -(len(shap) + 1) + if vdef.reducible: + kk_redu = get_reduce_name(kk) + model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis) + kk_derv_r, kk_derv_c = get_deriv_name(kk) + if vdef.c_differentiable: + size = 1 + for ii in 
+
+                split_ff = []
+                for ss in range(size):
+
+                    def eval_output(
+                        cc_ext, extended_atype, nlist, mapping, fparam, aparam
+                    ):
+                        atomic_ret = self.atomic_model.forward_common_atomic(
+                            cc_ext[None, ...],
+                            extended_atype[None, ...],
+                            nlist[None, ...],
+                            mapping=mapping[None, ...] if mapping is not None else None,
+                            fparam=fparam[None, ...] if fparam is not None else None,
+                            aparam=aparam[None, ...] if aparam is not None else None,
+                        )
+                        return jnp.sum(atomic_ret[kk][0], axis=atom_axis)[ss]
+
+                    ffi = -jax.vmap(jax.grad(eval_output, argnums=0))(
+                        extended_coord,
+                        extended_atype,
+                        nlist,
+                        mapping,
+                        fparam,
+                        aparam,
+                    )
+                    ffi = ffi[..., None, :]
+                    split_ff.append(ffi)
+                out_lead_shape = list(extended_coord.shape[:-1]) + vdef.shape
+                ff = jnp.concatenate(split_ff, axis=-2).reshape(*out_lead_shape, 3)
+
+                model_predict[kk_derv_r] = ff
+            if vdef.c_differentiable:
+                assert vdef.r_differentiable
+                model_predict[kk_derv_c] = None
+    return model_predict
diff --git a/deepmd/jax/model/ener_model.py b/deepmd/jax/model/ener_model.py
index 79c5a29e88..b1bf568544 100644
--- a/deepmd/jax/model/ener_model.py
+++ b/deepmd/jax/model/ener_model.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
     Any,
+    Optional,
 )
 
 from deepmd.dpmodel.model import EnergyModel as EnergyModelDP
@@ -10,8 +11,12 @@
 from deepmd.jax.common import (
     flax_module,
 )
+from deepmd.jax.env import (
+    jnp,
+)
 from deepmd.jax.model.base_model import (
     BaseModel,
+    forward_common_atomic,
 )
 
 
@@ -22,3 +27,24 @@ def __setattr__(self, name: str, value: Any) -> None:
         if name == "atomic_model":
             value = DPAtomicModel.deserialize(value.serialize())
         return super().__setattr__(name, value)
+
+    def forward_common_atomic(
+        self,
+        extended_coord: jnp.ndarray,
+        extended_atype: jnp.ndarray,
+        nlist: jnp.ndarray,
+        mapping: Optional[jnp.ndarray] = None,
+        fparam: Optional[jnp.ndarray] = None,
+        aparam: Optional[jnp.ndarray] = None,
+        do_atomic_virial: bool = False,
+    ):
+        return forward_common_atomic(
+            self,
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping=mapping,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py
index e3bf808978..de86cde2a6 100644
--- a/source/tests/consistent/common.py
+++ b/source/tests/consistent/common.py
@@ -69,6 +69,8 @@
     "INSTALLED_ARRAY_API_STRICT",
 ]
 
+SKIP_FLAG = object()
+
 
 class CommonTest(ABC):
     data: ClassVar[dict]
@@ -362,6 +364,8 @@ def test_dp_consistent_with_ref(self):
             data2 = dp_obj.serialize()
             np.testing.assert_equal(data1, data2)
         for rr1, rr2 in zip(ret1, ret2):
+            if rr1 is SKIP_FLAG or rr2 is SKIP_FLAG:
+                continue
             np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
             assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"
diff --git a/source/tests/consistent/model/common.py b/source/tests/consistent/model/common.py
index 4112e09cff..144dd1af1c 100644
--- a/source/tests/consistent/model/common.py
+++ b/source/tests/consistent/model/common.py
@@ -51,7 +51,7 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix):
             {},
             suffix=suffix,
         )
-        return [ret["energy"], ret["atom_ener"]], {
+        return [ret["energy"], ret["atom_ener"], ret["force"]], {
             t_coord: coords,
             t_type: atype,
             t_natoms: natoms,
diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py
index 78a2aac703..210ec82ae3 100644
--- a/source/tests/consistent/model/test_ener.py
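Note on the differentiation pattern introduced by this patch: forces never come from an analytical kernel; the per-atom energies are summed and differentiated with respect to the extended coordinates, one output component per `jax.grad` call (a later patch in this series switches to `jax.jacrev`). A minimal standalone sketch of that pattern, with a toy pair energy standing in for the atomic model (no neighbor list, mapping, or frame parameters):

    import jax
    import jax.numpy as jnp

    jax.config.update("jax_enable_x64", True)


    def toy_energy(coord):  # coord: [nall, 3]
        # inverse-distance pair energy; the eye() keeps the diagonal nonzero
        diff = coord[:, None, :] - coord[None, :, :]
        dist = jnp.sqrt(jnp.sum(diff**2, axis=-1) + jnp.eye(coord.shape[0]))
        return jnp.sum(jnp.triu(1.0 / dist, k=1))


    coord = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.5], [0.0, 1.5, 0.0]])
    force = -jax.grad(toy_energy)(coord)  # F = -dE/dx, shape [nall, 3]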
+++ b/source/tests/consistent/model/test_ener.py
@@ -16,6 +16,7 @@
     INSTALLED_JAX,
     INSTALLED_PT,
     INSTALLED_TF,
+    SKIP_FLAG,
     CommonTest,
     parameterized,
 )
@@ -94,6 +95,21 @@ def data(self) -> dict:
     jax_class = EnergyModelJAX
     args = model_args()
 
+    def get_reference_backend(self):
+        """Get the reference backend.
+
+        We need a reference backend that can reproduce forces.
+        """
+        if not self.skip_pt:
+            return self.RefBackend.PT
+        if not self.skip_tf:
+            return self.RefBackend.TF
+        if not self.skip_jax:
+            return self.RefBackend.JAX
+        if not self.skip_dp:
+            return self.RefBackend.DP
+        raise ValueError("No available reference")
+
     @property
     def skip_tf(self):
         return (
@@ -195,11 +211,19 @@ def eval_jax(self, jax_obj: Any) -> Any:
     def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
         # shape not matched. ravel...
         if backend is self.RefBackend.DP:
-            return (ret["energy_redu"].ravel(), ret["energy"].ravel())
+            return (ret["energy_redu"].ravel(), ret["energy"].ravel(), SKIP_FLAG)
         elif backend is self.RefBackend.PT:
-            return (ret["energy"].ravel(), ret["atom_energy"].ravel())
+            return (
+                ret["energy"].ravel(),
+                ret["atom_energy"].ravel(),
+                ret["force"].ravel(),
+            )
         elif backend is self.RefBackend.TF:
-            return (ret[0].ravel(), ret[1].ravel())
+            return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel())
         elif backend is self.RefBackend.JAX:
-            return (ret["energy_redu"].ravel(), ret["energy"].ravel())
+            return (
+                ret["energy_redu"].ravel(),
+                ret["energy"].ravel(),
+                ret["energy_derv_r"].ravel(),
+            )
         raise ValueError(f"Unknown backend: {backend}")

From d0b576f2c922e0b362d2edc103ee8cc2d7e4da9a Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Thu, 24 Oct 2024 22:55:01 -0400
Subject: [PATCH 2/8] do not change diff directly

Signed-off-by: Jinzhe Zeng
---
 deepmd/dpmodel/utils/env_mat.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py
index 3cd729f553..aa8520202e 100644
--- a/deepmd/dpmodel/utils/env_mat.py
+++ b/deepmd/dpmodel/utils/env_mat.py
@@ -62,8 +62,8 @@ def _make_env_mat(
     diff = coord_r - coord_l
     # nf x nloc x nnei
     # the grad of JAX vector_norm is NaN at x=0
-    diff = xp.where(xp.abs(diff) < 1e-30, xp.full_like(diff, 1e-30), diff)
-    length = xp.linalg.vector_norm(diff, axis=-1, keepdims=True)
+    diff_ = xp.where(xp.abs(diff) < 1e-30, xp.full_like(diff, 1e-30), diff)
+    length = xp.linalg.vector_norm(diff_, axis=-1, keepdims=True)
     # for index 0 nloc atom
     length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype)
     t0 = 1 / (length + protection)
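The two env_mat changes above guard a well-known JAX pitfall: the gradient of a Euclidean norm is x/|x|, which evaluates to 0/0 = NaN at x = 0, and masking the output after the fact does not help because the NaN from the untaken branch still propagates backward. Masking the input before the norm, as `diff_` does (without overwriting `diff`, which is still needed untouched downstream), keeps the gradient finite. A minimal reproduction, using `jnp.linalg.norm` in place of the array-API `vector_norm`:

    import jax
    import jax.numpy as jnp

    print(jax.grad(jnp.linalg.norm)(jnp.zeros(3)))  # [nan nan nan]


    def safe_norm(x):
        # mask the input, not the output, so no NaN enters the backward pass
        x = jnp.where(jnp.abs(x) < 1e-30, 1e-30, x)
        return jnp.linalg.norm(x)


    print(jax.grad(safe_norm)(jnp.zeros(3)))  # [0. 0. 0.]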
From 004b89a886a26b49dcef8b0a298cd45532537a07 Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Fri, 25 Oct 2024 01:36:17 -0400
Subject: [PATCH 3/8] support virial

Signed-off-by: Jinzhe Zeng
---
 deepmd/dpmodel/model/transform_output.py    | 53 +++++++++++++++++++---
 deepmd/jax/model/base_model.py              | 19 ++++++--
 source/tests/consistent/model/common.py     |  2 +-
 source/tests/consistent/model/test_ener.py  | 11 ++++-
 4 files changed, 72 insertions(+), 13 deletions(-)

diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py
index 63b9818397..1e5e57377a 100644
--- a/deepmd/dpmodel/model/transform_output.py
+++ b/deepmd/dpmodel/model/transform_output.py
@@ -52,7 +52,20 @@ def get_leading_dims(
     vv: np.ndarray,
     vdef: OutputVariableDef,
 ):
-    """Get the dimensions of nf x nloc."""
+    """Get the dimensions of nf x nloc.
+
+    Parameters
+    ----------
+    vv : np.ndarray
+        The input array from which to compute the leading dimensions.
+    vdef : OutputVariableDef
+        The output variable definition containing the shape to exclude from `vv`.
+
+    Returns
+    -------
+    list
+        A list of leading dimensions of `vv`, excluding the last `len(vdef.shape)` dimensions.
+    """
     vshape = vv.shape
     return list(vshape[: (len(vshape) - len(vdef.shape))])
 
@@ -76,11 +89,11 @@ def communicate_extended_output(
         if vdef.reducible:
             kk_redu = get_reduce_name(kk)
             new_ret[kk_redu] = model_ret[kk_redu]
+        kk_derv_r, kk_derv_c = get_deriv_name(kk)
+        mldims = list(mapping.shape)
+        vldims = get_leading_dims(vv, vdef)
         if vdef.r_differentiable:
-            kk_derv_r, kk_derv_c = get_deriv_name(kk)
             if model_ret[kk_derv_r] is not None:
-                mldims = list(mapping.shape)
-                vldims = get_leading_dims(vv, vdef)
                 derv_r_ext_dims = list(vdef.shape) + [3]  # noqa:RUF005
                 mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
                 mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
@@ -109,9 +122,35 @@ def communicate_extended_output(
                 new_ret[kk_derv_r] = None
         if vdef.c_differentiable:
             assert vdef.r_differentiable
-            kk_derv_r, kk_derv_c = get_deriv_name(kk)
-            new_ret[kk_derv_c] = None
-            new_ret[kk_derv_c + "_redu"] = None
+            if model_ret[kk_derv_c] is not None:
+                derv_c_ext_dims = list(vdef.shape) + [9]  # noqa:RUF005
+                mapping = xp.tile(
+                    mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
+                )
+                virial = xp.zeros(
+                    vldims + derv_c_ext_dims, dtype=vv.dtype, device=vv.device
+                )
+                # jax only
+                if array_api_compat.is_jax_array(virial):
+                    from deepmd.jax.env import (
+                        jnp,
+                    )
+
+                    v_idx = xp.arange(virial.size, dtype=xp.int64).reshape(
+                        virial.shape
+                    )
+                    new_idx = jnp.take_along_axis(v_idx, mapping, axis=1).ravel()
+                    v_shape = virial.shape
+                    virial = virial.ravel()
+                    virial = virial.at[new_idx].add(model_ret[kk_derv_c].ravel())
+                    virial = virial.reshape(v_shape)
+                else:
+                    raise NotImplementedError("Only JAX arrays are supported.")
+                new_ret[kk_derv_c] = virial
+                new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1)
+            else:
+                new_ret[kk_derv_c] = None
+                new_ret[kk_derv_c + "_redu"] = None
         if not do_atomic_virial:
             # pop atomic virial, because it is not correctly calculated.
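Shape-wise, the virial introduced in this patch is the outer product of each per-atom force contribution with that atom's extended coordinate, flattened from 3x3 to 9 components. A small sketch of just that tensor algebra with dummy arrays (the explicit reshape to 9 is the form this code takes after the fix in patch 7 of this series):

    import jax.numpy as jnp

    nf, nall = 1, 4
    ffi = jnp.ones((nf, nall, 3))             # force contribution per atom
    extended_coord = jnp.ones((nf, nall, 3))  # extended coordinates

    # [nf, nall, 3, 1] @ [nf, nall, 1, 3] -> [nf, nall, 3, 3]
    aviri = ffi[..., None] @ extended_coord[..., None, :]
    aviri = aviri.reshape(nf, nall, 9)        # flattened as the model outputs it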
             new_ret.pop(kk_derv_c)
diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py
index 3d4450468b..4c847792ee 100644
--- a/deepmd/jax/model/base_model.py
+++ b/deepmd/jax/model/base_model.py
@@ -53,6 +53,7 @@ def forward_common_atomic(
                     size *= ii
 
                 split_ff = []
+                split_vv = []
                 for ss in range(size):
 
                     def eval_output(
@@ -76,13 +77,25 @@ def eval_output(
                         fparam,
                         aparam,
                     )
+                    aviri = ffi[..., None] @ extended_coord[..., None, :]
                     ffi = ffi[..., None, :]
                     split_ff.append(ffi)
+                    aviri = aviri[..., None, :]
+                    split_vv.append(aviri)
                 out_lead_shape = list(extended_coord.shape[:-1]) + vdef.shape
-                ff = jnp.concatenate(split_ff, axis=-2).reshape(*out_lead_shape, 3)
+                extended_force = jnp.concat(split_ff, axis=-2).reshape(
+                    *out_lead_shape, 3
+                )
 
-                model_predict[kk_derv_r] = ff
+                model_predict[kk_derv_r] = extended_force
             if vdef.c_differentiable:
                 assert vdef.r_differentiable
-                model_predict[kk_derv_c] = None
+                extended_virial = jnp.concat(split_vv, axis=-2).reshape(
+                    *out_lead_shape, 9
+                )
+                # the correction sums to zero, which does not contribute to global virial
+                if do_atomic_virial:
+                    raise NotImplementedError("Atomic virial is not implemented yet.")
+                # to [...,3,3] -> [...,9]
+                model_predict[kk_derv_c] = extended_virial
     return model_predict
diff --git a/source/tests/consistent/model/common.py b/source/tests/consistent/model/common.py
index 144dd1af1c..11940d9bdf 100644
--- a/source/tests/consistent/model/common.py
+++ b/source/tests/consistent/model/common.py
@@ -51,7 +51,7 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix):
             {},
             suffix=suffix,
         )
-        return [ret["energy"], ret["atom_ener"], ret["force"]], {
+        return [ret["energy"], ret["atom_ener"], ret["force"], ret["virial"]], {
             t_coord: coords,
             t_type: atype,
             t_natoms: natoms,
diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py
index 210ec82ae3..2a358ba7e0 100644
--- a/source/tests/consistent/model/test_ener.py
+++ b/source/tests/consistent/model/test_ener.py
@@ -211,19 +211,26 @@ def eval_jax(self, jax_obj: Any) -> Any:
     def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
         # shape not matched. ravel...
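The `scatter_sum` helper added in the next patch mirrors the PyTorch-style scatter-add its docstring quotes: entries of `src` are accumulated into `input` at the positions named by `index` along `dim`. A usage sketch for the force reduction it replaces, with 2 local atoms, 4 extended atoms, and a hypothetical mapping sending extended atoms 0 and 2 to local atom 0 (the int64 index requires `jax_enable_x64`, which `deepmd.jax.env` turns on):

    import jax.numpy as jnp

    from deepmd.jax.common import scatter_sum

    force_ext = jnp.arange(12, dtype=jnp.float64).reshape(1, 4, 3)
    # mapping broadcast to the value shape, as communicate_extended_output does
    mapping = jnp.tile(jnp.array([[0, 1, 0, 1]])[..., None], (1, 1, 3))
    force_loc = scatter_sum(jnp.zeros((1, 2, 3)), 1, mapping, force_ext)
    # force_loc[0, 0] == force_ext[0, 0] + force_ext[0, 2]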
         if backend is self.RefBackend.DP:
-            return (ret["energy_redu"].ravel(), ret["energy"].ravel(), SKIP_FLAG)
+            return (
+                ret["energy_redu"].ravel(),
+                ret["energy"].ravel(),
+                SKIP_FLAG,
+                SKIP_FLAG,
+            )
         elif backend is self.RefBackend.PT:
             return (
                 ret["energy"].ravel(),
                 ret["atom_energy"].ravel(),
                 ret["force"].ravel(),
+                ret["virial"].ravel(),
             )
         elif backend is self.RefBackend.TF:
-            return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel())
+            return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel())
         elif backend is self.RefBackend.JAX:
             return (
                 ret["energy_redu"].ravel(),
                 ret["energy"].ravel(),
                 ret["energy_derv_r"].ravel(),
+                ret["energy_derv_c_redu"].ravel(),
             )
         raise ValueError(f"Unknown backend: {backend}")

From b9eefd3afd8e9728b95dda9abb38702ea033a88b Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Fri, 25 Oct 2024 01:57:16 -0400
Subject: [PATCH 4/8] resolve comments

Signed-off-by: Jinzhe Zeng
---
 deepmd/dpmodel/model/transform_output.py | 32 +++++++++++-------------
 deepmd/jax/common.py                     | 10 ++++++++
 deepmd/jax/model/base_model.py           | 13 ++++++++--
 3 files changed, 35 insertions(+), 20 deletions(-)

diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py
index 1e5e57377a..6b4097b889 100644
--- a/deepmd/dpmodel/model/transform_output.py
+++ b/deepmd/dpmodel/model/transform_output.py
@@ -102,18 +102,16 @@ def communicate_extended_output(
                 )
                 # jax only
                 if array_api_compat.is_jax_array(force):
-                    from deepmd.jax.env import (
-                        jnp,
+                    from deepmd.jax.common import (
+                        scatter_sum,
                     )
 
-                    f_idx = xp.arange(force.size, dtype=xp.int64).reshape(
-                        force.shape
+                    force = scatter_sum(
+                        force,
+                        1,
+                        mapping,
+                        model_ret[kk_derv_r],
                     )
-                    new_idx = jnp.take_along_axis(f_idx, mapping, axis=1).ravel()
-                    f_shape = force.shape
-                    force = force.ravel()
-                    force = force.at[new_idx].add(model_ret[kk_derv_r].ravel())
-                    force = force.reshape(f_shape)
                 else:
                     raise NotImplementedError("Only JAX arrays are supported.")
                 new_ret[kk_derv_r] = force
@@ -132,18 +130,16 @@ def communicate_extended_output(
                 )
                 # jax only
                 if array_api_compat.is_jax_array(virial):
-                    from deepmd.jax.env import (
-                        jnp,
+                    from deepmd.jax.common import (
+                        scatter_sum,
                     )
 
-                    v_idx = xp.arange(virial.size, dtype=xp.int64).reshape(
-                        virial.shape
+                    virial = scatter_sum(
+                        virial,
+                        1,
+                        mapping,
+                        model_ret[kk_derv_c],
                     )
-                    new_idx = jnp.take_along_axis(v_idx, mapping, axis=1).ravel()
-                    v_shape = virial.shape
-                    virial = virial.ravel()
-                    virial = virial.at[new_idx].add(model_ret[kk_derv_c].ravel())
-                    virial = virial.reshape(v_shape)
                 else:
                     raise NotImplementedError("Only JAX arrays are supported.")
                 new_ret[kk_derv_c] = virial
diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py
index f372e97eb5..59f36d11ad 100644
--- a/deepmd/jax/common.py
+++ b/deepmd/jax/common.py
@@ -95,3 +95,13 @@ def __dlpack__(self, *args, **kwargs):
 
     def __dlpack_device__(self, *args, **kwargs):
         return self.value.__dlpack_device__(*args, **kwargs)
+
+
+def scatter_sum(input, dim, index: jnp.ndarray, src: jnp.ndarray) -> jnp.ndarray:
+    """Reduces all values from the src tensor to the indices specified in the index tensor."""
+    idx = jnp.arange(input.size, dtype=jnp.int64).reshape(input.shape)
+    new_idx = jnp.take_along_axis(idx, index, axis=dim).ravel()
+    shape = input.shape
+    input = input.ravel()
+    input = input.at[new_idx].add(src.ravel())
+    return input.reshape(shape)
diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py
index 4c847792ee..46d3761073 100644
--- a/deepmd/jax/model/base_model.py
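A guess at the motivation for the "remove devices" patch that follows, since the commit message is terse: plain NumPy arrays (before NumPy 2.0) have no `.device` attribute, so `device=vv.device` breaks the non-JAX path, and dtype-only creation is the portable form across array-API namespaces. A hedged sketch of the portable call:

    import array_api_compat
    import numpy as np

    vv = np.zeros(3)
    xp = array_api_compat.array_namespace(vv)
    out = xp.zeros((2, 3), dtype=vv.dtype)  # no device kwarg; works everywhere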
--- a/deepmd/jax/model/base_model.py
+++ b/deepmd/jax/model/base_model.py
@@ -57,7 +57,16 @@ def forward_common_atomic(
                 for ss in range(size):
 
                     def eval_output(
-                        cc_ext, extended_atype, nlist, mapping, fparam, aparam
+                        cc_ext,
+                        extended_atype,
+                        nlist,
+                        mapping,
+                        fparam,
+                        aparam,
+                        *,
+                        _kk=kk,
+                        _ss=ss,
+                        _atom_axis=atom_axis,
                     ):
                         atomic_ret = self.atomic_model.forward_common_atomic(
                             cc_ext[None, ...],
@@ -67,7 +76,7 @@ def eval_output(
                             fparam=fparam[None, ...] if fparam is not None else None,
                             aparam=aparam[None, ...] if aparam is not None else None,
                         )
-                        return jnp.sum(atomic_ret[kk][0], axis=atom_axis)[ss]
+                        return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis)[_ss]

From 1bbb1c2b8832f1d96ebbdaa4ba6f27416861ec0f Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Fri, 25 Oct 2024 02:53:20 -0400
Subject: [PATCH 5/8] remove devices

Signed-off-by: Jinzhe Zeng
---
 deepmd/dpmodel/model/transform_output.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py
index 6b4097b889..dc46f4a9ce 100644
--- a/deepmd/dpmodel/model/transform_output.py
+++ b/deepmd/dpmodel/model/transform_output.py
@@ -97,9 +97,7 @@ def communicate_extended_output(
                 derv_r_ext_dims = list(vdef.shape) + [3]  # noqa:RUF005
                 mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
                 mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
-                force = xp.zeros(
-                    vldims + derv_r_ext_dims, dtype=vv.dtype, device=vv.device
-                )
+                force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
                 # jax only
                 if array_api_compat.is_jax_array(force):
                     from deepmd.jax.common import (
@@ -126,7 +124,8 @@ def communicate_extended_output(
                     mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
                 )
                 virial = xp.zeros(
-                    vldims + derv_c_ext_dims, dtype=vv.dtype, device=vv.device
+                    vldims + derv_c_ext_dims,
+                    dtype=vv.dtype,
                 )
                 # jax only
                 if array_api_compat.is_jax_array(virial):

From 23e318a61a19f3618070bc80e89e0209796d4513 Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Fri, 25 Oct 2024 19:21:26 -0400
Subject: [PATCH 6/8] do_atomic_virial cannot jit

Signed-off-by: Jinzhe Zeng
---
 deepmd/jax/model/base_model.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py
index 46d3761073..b77c54abec 100644
--- a/deepmd/jax/model/base_model.py
+++ b/deepmd/jax/model/base_model.py
@@ -103,8 +103,9 @@ def eval_output(
                     *out_lead_shape, 9
                 )
                 # the correction sums to zero, which does not contribute to global virial
-                if do_atomic_virial:
-                    raise NotImplementedError("Atomic virial is not implemented yet.")
+                # cannot jit
+                # if do_atomic_virial:
+                #     raise NotImplementedError("Atomic virial is not implemented yet.")
                 # to [...,3,3] -> [...,9]
                 model_predict[kk_derv_c] = extended_virial
     return model_predict
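Why the check above had to become a comment: inside `jax.jit`, a traced boolean has no concrete value, so ordinary Python control flow on it fails at trace time. A minimal reproduction (the usual workaround, marking the flag with `static_argnums`, is not applied here):

    import jax
    import jax.numpy as jnp


    @jax.jit
    def f(x, flag):
        if flag:  # Python branch on a traced value -> error at trace time
            raise NotImplementedError("atomic virial")
        return x


    try:
        f(jnp.ones(3), False)
    except jax.errors.ConcretizationTypeError as e:
        # TracerBoolConversionError on newer JAX, a subclass of this error
        print("cannot branch on a traced flag:", type(e).__name__)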
From c34d8b48331911a6cd47bd8cee6e33ddad95ec3e Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Sat, 26 Oct 2024 03:19:44 -0400
Subject: [PATCH 7/8] fix bug before following refactor

Signed-off-by: Jinzhe Zeng
---
 deepmd/jax/model/base_model.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py
index b77c54abec..9aa327ba6f 100644
--- a/deepmd/jax/model/base_model.py
+++ b/deepmd/jax/model/base_model.py
@@ -78,6 +78,8 @@ def eval_output(
                         )
                         return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis)[_ss]
 
+                    # extended_coord: [nf, nall, 3]
+                    # ffi: [nf, nall, 3]
                     ffi = -jax.vmap(jax.grad(eval_output, argnums=0))(
                         extended_coord,
                         extended_atype,
@@ -86,12 +88,21 @@ def eval_output(
                         fparam,
                         aparam,
                     )
+                    # ffi[..., None]: [nf, nall, 3, 1]
+                    # extended_coord[..., None, :]: [nf, nall, 1, 3]
+                    # aviri: [nf, nall, 3, 3]
                     aviri = ffi[..., None] @ extended_coord[..., None, :]
+                    # aviri: [nf, nall, 9]
+                    aviri = aviri.reshape(*aviri.shape[:-2], 9)
+                    # ffi: [nf, nall, 1, 3]
                     ffi = ffi[..., None, :]
                     split_ff.append(ffi)
+                    # aviri: [nf, nall, 1, 9]
                     aviri = aviri[..., None, :]
                     split_vv.append(aviri)
                 out_lead_shape = list(extended_coord.shape[:-1]) + vdef.shape
+                # extended_force: [nf, nall, def_size, 3]
+                # extended_force: [nf, nall, *def, 3]
                 extended_force = jnp.concat(split_ff, axis=-2).reshape(
                     *out_lead_shape, 3
                 )
@@ -99,6 +110,8 @@ def eval_output(
                 model_predict[kk_derv_r] = extended_force
             if vdef.c_differentiable:
                 assert vdef.r_differentiable
+                # extended_virial: [nf, nall, def_size, 9]
+                # extended_virial: [nf, nall, *def, 9]
                 extended_virial = jnp.concat(split_vv, axis=-2).reshape(
                     *out_lead_shape, 9
                 )

From 19e09708d679f12460218b0546fdf558de6cb68a Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Sat, 26 Oct 2024 03:25:30 -0400
Subject: [PATCH 8/8] refactor

Signed-off-by: Jinzhe Zeng
---
 deepmd/jax/model/base_model.py | 97 ++++++++++++++--------------------
 1 file changed, 40 insertions(+), 57 deletions(-)

diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py
index 9aa327ba6f..8631c85d16 100644
--- a/deepmd/jax/model/base_model.py
+++ b/deepmd/jax/model/base_model.py
@@ -48,73 +48,56 @@ def forward_common_atomic(
             model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
             kk_derv_r, kk_derv_c = get_deriv_name(kk)
             if vdef.c_differentiable:
-                size = 1
-                for ii in vdef.shape:
-                    size *= ii
-
-                split_ff = []
-                split_vv = []
-                for ss in range(size):
-
-                    def eval_output(
-                        cc_ext,
-                        extended_atype,
-                        nlist,
-                        mapping,
-                        fparam,
-                        aparam,
-                        *,
-                        _kk=kk,
-                        _ss=ss,
-                        _atom_axis=atom_axis,
-                    ):
-                        atomic_ret = self.atomic_model.forward_common_atomic(
-                            cc_ext[None, ...],
-                            extended_atype[None, ...],
-                            nlist[None, ...],
-                            mapping=mapping[None, ...] if mapping is not None else None,
-                            fparam=fparam[None, ...] if fparam is not None else None,
-                            aparam=aparam[None, ...] if aparam is not None else None,
-                        )
-                        return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis)[_ss]
-
-                    # extended_coord: [nf, nall, 3]
-                    # ffi: [nf, nall, 3]
-                    ffi = -jax.vmap(jax.grad(eval_output, argnums=0))(
-                        extended_coord,
-                        extended_atype,
-                        nlist,
-                        mapping,
-                        fparam,
-                        aparam,
-                    )
-                    # ffi[..., None]: [nf, nall, 3, 1]
-                    # extended_coord[..., None, :]: [nf, nall, 1, 3]
-                    # aviri: [nf, nall, 3, 3]
-                    aviri = ffi[..., None] @ extended_coord[..., None, :]
-                    # aviri: [nf, nall, 9]
-                    aviri = aviri.reshape(*aviri.shape[:-2], 9)
-                    # ffi: [nf, nall, 1, 3]
-                    ffi = ffi[..., None, :]
-                    split_ff.append(ffi)
-                    # aviri: [nf, nall, 1, 9]
-                    aviri = aviri[..., None, :]
-                    split_vv.append(aviri)
-                out_lead_shape = list(extended_coord.shape[:-1]) + vdef.shape
-                # extended_force: [nf, nall, def_size, 3]
+                def eval_output(
+                    cc_ext,
+                    extended_atype,
+                    nlist,
+                    mapping,
+                    fparam,
+                    aparam,
+                    *,
+                    _kk=kk,
+                    _atom_axis=atom_axis,
+                ):
+                    atomic_ret = self.atomic_model.forward_common_atomic(
+                        cc_ext[None, ...],
+                        extended_atype[None, ...],
+                        nlist[None, ...],
+                        mapping=mapping[None, ...] if mapping is not None else None,
+                        fparam=fparam[None, ...] if fparam is not None else None,
+                        aparam=aparam[None, ...] if aparam is not None else None,
+                    )
+                    return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis)
+
+                # extended_coord: [nf, nall, 3]
+                # ff: [nf, *def, nall, 3]
+                ff = -jax.vmap(jax.jacrev(eval_output, argnums=0))(
+                    extended_coord,
+                    extended_atype,
+                    nlist,
+                    mapping,
+                    fparam,
+                    aparam,
+                )
                 # extended_force: [nf, nall, *def, 3]
-                extended_force = jnp.concat(split_ff, axis=-2).reshape(
-                    *out_lead_shape, 3
+                def_ndim = len(vdef.shape)
+                extended_force = jnp.transpose(
+                    ff, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
                 )
 
                 model_predict[kk_derv_r] = extended_force
             if vdef.c_differentiable:
                 assert vdef.r_differentiable
-                # extended_virial: [nf, nall, def_size, 9]
+                # avr: [nf, *def, nall, 3, 3]
+                avr = jnp.einsum("f...ai,faj->f...aij", ff, extended_coord)
+                # avr: [nf, *def, nall, 9]
+                avr = jnp.reshape(avr, [*ff.shape[:-1], 9])
                 # extended_virial: [nf, nall, *def, 9]
-                extended_virial = jnp.concat(split_vv, axis=-2).reshape(
-                    *out_lead_shape, 9
+                extended_virial = jnp.transpose(
+                    avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
                 )
 
                 # the correction sums to zero, which does not contribute to global virial
                 # cannot jit
                 # if do_atomic_virial:
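The net effect of the refactor above: one `jax.jacrev` call differentiates the whole vector of outputs in a single pass, replacing the per-component `jax.grad` loop over `ss`, after which the `jnp.transpose` calls move the output axes behind the atom axis to recover the `[nf, nall, *def, 3]` layout. A standalone sketch of the equivalence with a toy two-component function of coordinates:

    import jax
    import jax.numpy as jnp


    def outputs(coord):  # coord: [nall, 3] -> [2]
        return jnp.stack([jnp.sum(coord**2), jnp.sum(coord**3)])


    coord = jnp.ones((5, 3))
    ff = -jax.jacrev(outputs)(coord)  # [2, nall, 3]: all components at once
    f0 = -jax.grad(lambda c: outputs(c)[0])(coord)  # the old per-ss pattern
    assert jnp.allclose(ff[0], f0)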