From 71255cb99b2100c82f0df21c1cf6ec3984111be7 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 23:37:40 -0500 Subject: [PATCH] fix(jax): calculate virial in call_lower Signed-off-by: Jinzhe Zeng --- deepmd/jax/model/base_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py index 1e880700a2..44152a4c26 100644 --- a/deepmd/jax/model/base_model.py +++ b/deepmd/jax/model/base_model.py @@ -152,4 +152,6 @@ def eval_ce( avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2] ) model_predict[kk_derv_c] = extended_virial + # [nf, *def, 9] + model_predict[kk_derv_c + "_redu"] = jnp.sum(extended_virial, axis=1) return model_predict