Skip to content

Commit

Permalink
remove devices
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 25, 2024
1 parent b9eefd3 commit 1bbb1c2
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ def communicate_extended_output(
derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005
mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
force = xp.zeros(
vldims + derv_r_ext_dims, dtype=vv.dtype, device=vv.device
)
force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
# jax only
if array_api_compat.is_jax_array(force):
from deepmd.jax.common import (
Expand All @@ -126,7 +124,8 @@ def communicate_extended_output(
mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
)
virial = xp.zeros(
vldims + derv_c_ext_dims, dtype=vv.dtype, device=vv.device
vldims + derv_c_ext_dims,
dtype=vv.dtype,
)
# jax only
if array_api_compat.is_jax_array(virial):
Expand Down

0 comments on commit 1bbb1c2

Please sign in to comment.