Skip to content

Commit

Permalink
fix(jax): calculate virial in call_lower
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 4, 2024
1 parent bfbe2ed commit 71255cb
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions deepmd/jax/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,6 @@ def eval_ce(
avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
)
model_predict[kk_derv_c] = extended_virial
# [nf, *def, 9]
model_predict[kk_derv_c + "_redu"] = jnp.sum(extended_virial, axis=1)
return model_predict

0 comments on commit 71255cb

Please sign in to comment.