diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index 422dcb5f17..ba8858d6b9 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -58,7 +58,7 @@ jobs: env: NUM_WORKERS: 0 - name: Test TF2 eager mode - run: pytest --cov=deepmd source/tests/consistent/io/test_io.py --durations=0 + run: pytest --cov=deepmd --cov-append source/tests/consistent/io/test_io.py source/jax2tf_tests --durations=0 env: NUM_WORKERS: 0 DP_TEST_TF2_ONLY: 1 diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index fc526a502e..b9d1974c27 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -97,6 +97,12 @@ def __init__( stablehlo_atomic_virial=model_data["@variables"][ "stablehlo_atomic_virial" ].tobytes(), + stablehlo_no_ghost=model_data["@variables"][ + "stablehlo_no_ghost" + ].tobytes(), + stablehlo_atomic_virial_no_ghost=model_data["@variables"][ + "stablehlo_atomic_virial_no_ghost" + ].tobytes(), model_def_script=model_data["model_def_script"], **model_data["constants"], ) diff --git a/deepmd/jax/jax2tf/__init__.py b/deepmd/jax/jax2tf/__init__.py index 88a928f04d..c2cda24bd7 100644 --- a/deepmd/jax/jax2tf/__init__.py +++ b/deepmd/jax/jax2tf/__init__.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import tensorflow as tf +import tensorflow.experimental.numpy as tnp if not tf.executing_eagerly(): # TF disallow temporary eager execution @@ -9,3 +10,5 @@ "If you are converting a model between different backends, " "considering converting to the `.dp` format first." 
) + +tnp.experimental_enable_numpy_behavior() diff --git a/deepmd/jax/jax2tf/make_model.py b/deepmd/jax/jax2tf/make_model.py new file mode 100644 index 0000000000..d21fc998b5 --- /dev/null +++ b/deepmd/jax/jax2tf/make_model.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, +) + +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +from deepmd.dpmodel.output_def import ( + ModelOutputDef, +) +from deepmd.jax.jax2tf.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.jax.jax2tf.region import ( + normalize_coord, +) +from deepmd.jax.jax2tf.transform_output import ( + communicate_extended_output, +) + + +def model_call_from_call_lower( + *, # enforce keyword-only arguments + call_lower: Callable[ + [ + tnp.ndarray, + tnp.ndarray, + tnp.ndarray, + tnp.ndarray, + tnp.ndarray, + bool, + ], + dict[str, tnp.ndarray], + ], + rcut: float, + sel: list[int], + mixed_types: bool, + model_output_def: ModelOutputDef, + coord: tnp.ndarray, + atype: tnp.ndarray, + box: tnp.ndarray, + fparam: tnp.ndarray, + aparam: tnp.ndarray, + do_atomic_virial: bool = False, +): + """Return model prediction from lower interface. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type dict[str,tnp.ndarray]. + The keys are defined by the `ModelOutputDef`. 
+ + """ + atype_shape = tf.shape(atype) + nframes, nloc = atype_shape[0], atype_shape[1] + cc, bb, fp, ap = coord, box, fparam, aparam + del coord, box, fparam, aparam + if tf.shape(bb)[-1] != 0: + coord_normalized = normalize_coord( + cc.reshape(nframes, nloc, 3), + bb.reshape(nframes, 3, 3), + ) + else: + coord_normalized = cc + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, bb, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + distinguish_types=not mixed_types, + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + model_predict_lower = call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fp, + aparam=ap, + ) + model_predict = communicate_extended_output( + model_predict_lower, + model_output_def, + mapping, + do_atomic_virial=do_atomic_virial, + ) + return model_predict diff --git a/deepmd/jax/jax2tf/nlist.py b/deepmd/jax/jax2tf/nlist.py new file mode 100644 index 0000000000..5a0ed58b63 --- /dev/null +++ b/deepmd/jax/jax2tf/nlist.py @@ -0,0 +1,217 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Union, +) + +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +from .region import ( + to_face_distance, +) + + +## translated from torch implementation by chatgpt +def build_neighbor_list( + coord: tnp.ndarray, + atype: tnp.ndarray, + nloc: int, + rcut: float, + sel: Union[int, list[int]], + distinguish_types: bool = True, +) -> tnp.ndarray: + """Build the neighbor list for every frame in the batch, keeping nsel neighbors. + + Parameters + ---------- + coord : tnp.ndarray + extended coordinates of shape [batch_size, nall x 3] + atype : tnp.ndarray + extended atomic types of shape [batch_size, nall] + if type < 0, the atom is treated as a virtual atom. + nloc : int + number of local atoms. + rcut : float + cut-off radius + sel : int or list[int] + maximal number of neighbors (of each type). 
+ if distinguish_types==True, nsel should be list and + the length of nsel should be equal to number of + types. + distinguish_types : bool + distinguish different types. + + Returns + ------- + neighbor_list : tnp.ndarray + Neighbor list of shape [batch_size, nloc, nsel], the neighbors + are stored in an ascending order. If the number of + neighbors is less than nsel, the positions are masked + with -1. The neighbor list of an atom looks like + |------ nsel ------| + xx xx xx xx -1 -1 -1 + if distinguish_types==True and we have two types + |---- nsel[0] -----| |---- nsel[1] -----| + xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1 + For virtual atoms all neighboring positions are filled with -1. + + """ + batch_size = tf.shape(coord)[0] + coord = tnp.reshape(coord, (batch_size, -1)) + nall = tf.shape(coord)[1] // 3 + # fill virtual atoms with large coords so they are not neighbors of any + # real atom. + if tf.size(coord) > 0: + xmax = tnp.max(coord) + 2.0 * rcut + else: + xmax = tf.cast(2.0 * rcut, coord.dtype) + # nf x nall + is_vir = atype < 0 + coord1 = tnp.where( + is_vir[:, :, None], xmax, tnp.reshape(coord, (batch_size, nall, 3)) + ) + coord1 = tnp.reshape(coord1, (batch_size, nall * 3)) + if isinstance(sel, int): + sel = [sel] + nsel = sum(sel) + coord0 = coord1[:, : nloc * 3] + diff = ( + tnp.reshape(coord1, [batch_size, -1, 3])[:, None, :, :] + - tnp.reshape(coord0, [batch_size, -1, 3])[:, :, None, :] + ) + rr = tf.linalg.norm(diff, axis=-1) + # if central atom has two zero distances, sorting sometimes can not exclude itself + rr -= tf.eye(nloc, nall, dtype=diff.dtype)[tnp.newaxis, :, :] + nlist = tnp.argsort(rr, axis=-1) + rr = tnp.sort(rr, axis=-1) + rr = rr[:, :, 1:] + nlist = nlist[:, :, 1:] + nnei = tf.shape(rr)[2] + if nsel <= nnei: + rr = rr[:, :, :nsel] + nlist = nlist[:, :, :nsel] + else: + rr = tnp.concatenate( + [rr, tnp.ones([batch_size, nloc, nsel - nnei], dtype=rr.dtype) + rcut], + axis=-1, + ) + nlist = tnp.concatenate( + [nlist, 
tnp.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)], + axis=-1, + ) + nlist = tnp.where( + tnp.logical_or((rr > rcut), is_vir[:, :nloc, None]), + tnp.full_like(nlist, -1), + nlist, + ) + + if distinguish_types: + return nlist_distinguish_types(nlist, atype, sel) + else: + return nlist + + +def nlist_distinguish_types( + nlist: tnp.ndarray, + atype: tnp.ndarray, + sel: list[int], +): + """Given a nlist that does not distinguish atom types, return a nlist that + distinguishes atom types. + + """ + nloc = tf.shape(nlist)[1] + ret_nlist = [] + tmp_atype = tnp.tile(atype[:, None, :], (1, nloc, 1)) + mask = nlist == -1 + tnlist_0 = tnp.where(mask, tnp.zeros_like(nlist), nlist) + tnlist = tnp.take_along_axis(tmp_atype, tnlist_0, axis=2) + tnlist = tnp.where(mask, tnp.full_like(tnlist, -1), tnlist) + for ii, ss in enumerate(sel): + pick_mask = tf.cast(tnlist == ii, tnp.int32) + sorted_indices = tnp.argsort(-pick_mask, kind="stable", axis=-1) + pick_mask_sorted = -tnp.sort(-pick_mask, axis=-1) + inlist = tnp.take_along_axis(nlist, sorted_indices, axis=2) + inlist = tnp.where( + ~tf.cast(pick_mask_sorted, tf.bool), tnp.full_like(inlist, -1), inlist + ) + ret_nlist.append(inlist[..., :ss]) + ret = tf.concat(ret_nlist, axis=-1) + return ret + + +def tf_outer(a, b): + return tf.einsum("i,j->ij", a, b) + + +## translated from torch implementation by chatgpt +def extend_coord_with_ghosts( + coord: tnp.ndarray, + atype: tnp.ndarray, + cell: tnp.ndarray, + rcut: float, +): + """Extend the coordinates of the atoms by appending periodic images. + The number of images is large enough to ensure all the neighbors + within rcut are appended. + + Parameters + ---------- + coord : tnp.ndarray + original coordinates of shape [-1, nloc*3]. + atype : tnp.ndarray + atom type of shape [-1, nloc]. + cell : tnp.ndarray + simulation cell tensor of shape [-1, 9]. 
+ rcut : float + the cutoff radius + + Returns + ------- + extended_coord: tnp.ndarray + extended coordinates of shape [-1, nall*3]. + extended_atype: tnp.ndarray + extended atom type of shape [-1, nall]. + index_mapping: tnp.ndarray + mapping extended index to the local index + + """ + atype_shape = tf.shape(atype) + nf, nloc = atype_shape[0], atype_shape[1] + # int64 for index + aidx = tf.range(nloc, dtype=tnp.int64) + aidx = tnp.tile(aidx[tnp.newaxis, :], (nf, 1)) + if tf.shape(cell)[-1] == 0: + nall = nloc + extend_coord = coord + extend_atype = atype + extend_aidx = aidx + else: + coord = tnp.reshape(coord, (nf, nloc, 3)) + cell = tnp.reshape(cell, (nf, 3, 3)) + to_face = to_face_distance(cell) + nbuff = tf.cast(tnp.ceil(rcut / to_face), tnp.int64) + nbuff = tnp.max(nbuff, axis=0) + xi = tf.range(-nbuff[0], nbuff[0] + 1, 1, dtype=tnp.int64) + yi = tf.range(-nbuff[1], nbuff[1] + 1, 1, dtype=tnp.int64) + zi = tf.range(-nbuff[2], nbuff[2] + 1, 1, dtype=tnp.int64) + xyz = tf_outer(xi, tnp.asarray([1, 0, 0]))[:, tnp.newaxis, tnp.newaxis, :] + xyz = xyz + tf_outer(yi, tnp.asarray([0, 1, 0]))[tnp.newaxis, :, tnp.newaxis, :] + xyz = xyz + tf_outer(zi, tnp.asarray([0, 0, 1]))[tnp.newaxis, tnp.newaxis, :, :] + xyz = tnp.reshape(xyz, (-1, 3)) + xyz = tf.cast(xyz, coord.dtype) + shift_idx = tnp.take(xyz, tnp.argsort(tf.linalg.norm(xyz, axis=1)), axis=0) + ns = tf.shape(shift_idx)[0] + nall = ns * nloc + shift_vec = tnp.einsum("sd,fdk->fsk", shift_idx, cell) + # shift_vec = tnp.tensordot(shift_idx, cell, axes=([1], [1])) + # shift_vec = tnp.transpose(shift_vec, (1, 0, 2)) + extend_coord = coord[:, None, :, :] + shift_vec[:, :, None, :] + extend_atype = tnp.tile(atype[:, :, tnp.newaxis], (1, ns, 1)) + extend_aidx = tnp.tile(aidx[:, :, tnp.newaxis], (1, ns, 1)) + + return ( + tnp.reshape(extend_coord, (nf, nall * 3)), + tnp.reshape(extend_atype, (nf, nall)), + tnp.reshape(extend_aidx, (nf, nall)), + ) diff --git a/deepmd/jax/jax2tf/region.py b/deepmd/jax/jax2tf/region.py new 
file mode 100644 index 0000000000..96024bd79a --- /dev/null +++ b/deepmd/jax/jax2tf/region.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + + +def phys2inter( + coord: tnp.ndarray, + cell: tnp.ndarray, +) -> tnp.ndarray: + """Convert physical coordinates to internal(direct) coordinates. + + Parameters + ---------- + coord : tnp.ndarray + physical coordinates of shape [*, na, 3]. + cell : tnp.ndarray + simulation cell tensor of shape [*, 3, 3]. + + Returns + ------- + inter_coord: tnp.ndarray + the internal coordinates + + """ + rec_cell = tf.linalg.inv(cell) + return tnp.matmul(coord, rec_cell) + + +def inter2phys( + coord: tnp.ndarray, + cell: tnp.ndarray, +) -> tnp.ndarray: + """Convert internal(direct) coordinates to physical coordinates. + + Parameters + ---------- + coord : tnp.ndarray + internal coordinates of shape [*, na, 3]. + cell : tnp.ndarray + simulation cell tensor of shape [*, 3, 3]. + + Returns + ------- + phys_coord: tnp.ndarray + the physical coordinates + + """ + return tnp.matmul(coord, cell) + + +def normalize_coord( + coord: tnp.ndarray, + cell: tnp.ndarray, +) -> tnp.ndarray: + """Apply PBC according to the atomic coordinates. + + Parameters + ---------- + coord : tnp.ndarray + original coordinates of shape [*, na, 3]. + cell : tnp.ndarray + simulation cell shape [*, 3, 3]. + + Returns + ------- + wrapped_coord: tnp.ndarray + wrapped coordinates of shape [*, na, 3]. + + """ + icoord = phys2inter(coord, cell) + icoord = tnp.remainder(icoord, 1.0) + return inter2phys(icoord, cell) + + +def to_face_distance( + cell: tnp.ndarray, +) -> tnp.ndarray: + """Compute the to-face-distance of the simulation cell. + + Parameters + ---------- + cell : tnp.ndarray + simulation cell tensor of shape [*, 3, 3]. 
+ + Returns + ------- + dist: tnp.ndarray + the to face distances of shape [*, 3] + + """ + cshape = tf.shape(cell) + dist = b_to_face_distance(tnp.reshape(cell, [-1, 3, 3])) + return tnp.reshape(dist, tf.concat([cshape[:-2], [3]], axis=0)) + + +def b_to_face_distance(cell): + volume = tf.linalg.det(cell) + c_yz = tf.linalg.cross(cell[:, 1, ...], cell[:, 2, ...]) + _h2yz = volume / tf.linalg.norm(c_yz, axis=-1) + c_zx = tf.linalg.cross(cell[:, 2, ...], cell[:, 0, ...]) + _h2zx = volume / tf.linalg.norm(c_zx, axis=-1) + c_xy = tf.linalg.cross(cell[:, 0, ...], cell[:, 1, ...]) + _h2xy = volume / tf.linalg.norm(c_xy, axis=-1) + return tnp.stack([_h2yz, _h2zx, _h2xy], axis=1) diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index dff43a11fc..7e560f6008 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -1,11 +1,18 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import json +from typing import ( + Optional, +) import tensorflow as tf +import tensorflow.experimental.numpy as tnp from jax.experimental import ( jax2tf, ) +from deepmd.jax.jax2tf.make_model import ( + model_call_from_call_lower, +) from deepmd.jax.model.base_model import ( BaseModel, ) @@ -28,7 +35,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None: tf_model = tf.Module() - def exported_whether_do_atomic_virial(do_atomic_virial): + def exported_whether_do_atomic_virial(do_atomic_virial, has_ghost_atoms): def call_lower_with_fixed_do_atomic_virial( coord, atype, nlist, mapping, fparam, aparam ): @@ -42,13 +49,20 @@ def call_lower_with_fixed_do_atomic_virial( do_atomic_virial=do_atomic_virial, ) + # nghost >= 1 is assumed if there is + # other workaround does not work, such as + # nall; nloc + nghost - 1 + if has_ghost_atoms: + nghost = "nghost" + else: + nghost = "0" return jax2tf.convert( call_lower_with_fixed_do_atomic_virial, polymorphic_shapes=[ - "(nf, nloc + nghost, 3)", - "(nf, nloc + nghost)", + f"(nf, nloc + 
{nghost}, 3)", + f"(nf, nloc + {nghost})", f"(nf, nloc, {model.get_nnei()})", - "(nf, nloc + nghost)", + f"(nf, nloc + {nghost})", f"(nf, {model.get_dim_fparam()})", f"(nf, nloc, {model.get_dim_aparam()})", ], @@ -71,8 +85,14 @@ def call_lower_with_fixed_do_atomic_virial( def call_lower_without_atomic_virial( coord, atype, nlist, mapping, fparam, aparam ): - return exported_whether_do_atomic_virial(do_atomic_virial=False)( - coord, atype, nlist, mapping, fparam, aparam + return tf.cond( + tf.shape(coord)[1] == tf.shape(nlist)[1], + lambda: exported_whether_do_atomic_virial( + do_atomic_virial=False, has_ghost_atoms=False + )(coord, atype, nlist, mapping, fparam, aparam), + lambda: exported_whether_do_atomic_virial( + do_atomic_virial=False, has_ghost_atoms=True + )(coord, atype, nlist, mapping, fparam, aparam), ) tf_model.call_lower = call_lower_without_atomic_virial @@ -89,12 +109,116 @@ def call_lower_without_atomic_virial( ], ) def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam): - return exported_whether_do_atomic_virial(do_atomic_virial=True)( - coord, atype, nlist, mapping, fparam, aparam + return tf.cond( + tf.shape(coord)[1] == tf.shape(nlist)[1], + lambda: exported_whether_do_atomic_virial( + do_atomic_virial=True, has_ghost_atoms=False + )(coord, atype, nlist, mapping, fparam, aparam), + lambda: exported_whether_do_atomic_virial( + do_atomic_virial=True, has_ghost_atoms=True + )(coord, atype, nlist, mapping, fparam, aparam), ) tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial + def make_call_whether_do_atomic_virial(do_atomic_virial: bool): + if do_atomic_virial: + call_lower = call_lower_with_atomic_virial + else: + call_lower = call_lower_without_atomic_virial + + def call( + coord: tnp.ndarray, + atype: tnp.ndarray, + box: Optional[tnp.ndarray] = None, + fparam: Optional[tnp.ndarray] = None, + aparam: Optional[tnp.ndarray] = None, + ): + """Return model prediction. 
+ + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + + Returns + ------- + ret_dict + The result dict of type dict[str,tnp.ndarray]. + The keys are defined by the `ModelOutputDef`. + + """ + return model_call_from_call_lower( + call_lower=call_lower, + rcut=model.get_rcut(), + sel=model.get_sel(), + mixed_types=model.mixed_types(), + model_output_def=model.model_output_def(), + coord=coord, + atype=atype, + box=box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + return call + + @tf.function( + autograph=True, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, None], tf.float64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], + ) + def call_with_atomic_virial( + coord: tnp.ndarray, + atype: tnp.ndarray, + box: tnp.ndarray, + fparam: tnp.ndarray, + aparam: tnp.ndarray, + ): + return make_call_whether_do_atomic_virial(do_atomic_virial=True)( + coord, atype, box, fparam, aparam + ) + + tf_model.call_atomic_virial = call_with_atomic_virial + + @tf.function( + autograph=True, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, None], tf.float64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], + ) + def call_without_atomic_virial( + coord: tnp.ndarray, + atype: tnp.ndarray, + box: tnp.ndarray, + fparam: tnp.ndarray, + aparam: tnp.ndarray, + ): + return make_call_whether_do_atomic_virial(do_atomic_virial=False)( + coord, atype, box, fparam, aparam + ) + + tf_model.call = call_without_atomic_virial + 
# set functions to export other attributes @tf.function def get_type_map(): diff --git a/deepmd/jax/jax2tf/tfmodel.py b/deepmd/jax/jax2tf/tfmodel.py index 8f04014a97..0d7b13ba1f 100644 --- a/deepmd/jax/jax2tf/tfmodel.py +++ b/deepmd/jax/jax2tf/tfmodel.py @@ -7,9 +7,6 @@ import jax.experimental.jax2tf as jax2tf import tensorflow as tf -from deepmd.dpmodel.model.make_model import ( - model_call_from_call_lower, -) from deepmd.dpmodel.output_def import ( FittingOutputDef, ModelOutputDef, @@ -55,6 +52,8 @@ def __init__( self._call_lower_atomic_virial = jax2tf.call_tf( self.model.call_lower_atomic_virial ) + self._call = jax2tf.call_tf(self.model.call) + self._call_atomic_virial = jax2tf.call_tf(self.model.call_atomic_virial) self.type_map = decode_list_of_bytes(self.model.get_type_map().numpy().tolist()) self.rcut = self.model.get_rcut().numpy().item() self.dim_fparam = self.model.get_dim_fparam().numpy().item() @@ -142,18 +141,28 @@ def call( The keys are defined by the `ModelOutputDef`. """ - return model_call_from_call_lower( - call_lower=self.call_lower, - rcut=self.get_rcut(), - sel=self.get_sel(), - mixed_types=self.mixed_types(), - model_output_def=self.model_output_def(), - coord=coord, - atype=atype, - box=box, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, + if do_atomic_virial: + call = self._call_atomic_virial + else: + call = self._call + # Attempt to convert a value (None) with an unsupported type () to a Tensor. 
+ if box is None: + box = jnp.empty((coord.shape[0], 0, 0), dtype=jnp.float64) + if fparam is None: + fparam = jnp.empty( + (coord.shape[0], self.get_dim_fparam()), dtype=jnp.float64 + ) + if aparam is None: + aparam = jnp.empty( + (coord.shape[0], coord.shape[1], self.get_dim_aparam()), + dtype=jnp.float64, + ) + return call( + coord, + atype, + box, + fparam, + aparam, ) def model_output_def(self): diff --git a/deepmd/jax/jax2tf/transform_output.py b/deepmd/jax/jax2tf/transform_output.py new file mode 100644 index 0000000000..f853744c02 --- /dev/null +++ b/deepmd/jax/jax2tf/transform_output.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +from deepmd.dpmodel.output_def import ( + ModelOutputDef, + OutputVariableDef, + get_deriv_name, + get_reduce_name, +) + + +def get_leading_dims( + vv: tnp.ndarray, + vdef: OutputVariableDef, +) -> tnp.ndarray: + """Get the dimensions of nf x nloc. + + Parameters + ---------- + vv : tnp.ndarray + The input array from which to compute the leading dimensions. + vdef : OutputVariableDef + The output variable definition containing the shape to exclude from `vv`. + + Returns + ------- + tnp.ndarray + A slice of `tf.shape(vv)` holding the leading dimensions of `vv`, excluding the last `len(vdef.shape)` dimensions. + """ + vshape = tf.shape(vv) + return vshape[: (len(vshape) - len(vdef.shape))] + + +def communicate_extended_output( + model_ret: dict[str, tnp.ndarray], + model_output_def: ModelOutputDef, + mapping: tnp.ndarray, # nf x nall (extended index -> local index) + do_atomic_virial: bool = False, +) -> dict[str, tnp.ndarray]: + """Transform the output of the model network defined on + local and ghost (extended) atoms to local atoms. 
+ + """ + new_ret = {} + for kk in model_output_def.keys_outp(): + vv = model_ret[kk] + vdef = model_output_def[kk] + new_ret[kk] = vv + if vdef.reducible: + kk_redu = get_reduce_name(kk) + new_ret[kk_redu] = model_ret[kk_redu] + kk_derv_r, kk_derv_c = get_deriv_name(kk) + mldims = tf.shape(mapping) + vldims = get_leading_dims(vv, vdef) + if vdef.r_differentiable: + if model_ret[kk_derv_r] is not None: + derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005 + indices = mapping.reshape(tf.shape(mapping)[0], -1, 1) + # concat frame idx + indices = tf.concat( + [ + tf.repeat( + tf.range(tf.shape(indices)[0], dtype=indices.dtype), + tf.shape(mapping)[1], + ).reshape(tf.shape(indices)), + indices, + ], + axis=-1, + ) + force = tf.scatter_nd( + indices, + model_ret[kk_derv_r], + tf.cast(tf.concat([vldims, derv_r_ext_dims], axis=0), tf.int64), + ) + new_ret[kk_derv_r] = force.reshape( + tf.concat([tf.shape(force)[:2], list(vdef.shape), [3]], axis=0) + ) + else: + # name holders + new_ret[kk_derv_r] = None + if vdef.c_differentiable: + assert vdef.r_differentiable + if model_ret[kk_derv_c] is not None: + derv_c_ext_dims = list(vdef.shape) + [9] # noqa:RUF005 + indices = mapping.reshape(tf.shape(mapping)[0], -1, 1) + # concat frame idx + indices = tf.concat( + [ + tf.repeat( + tf.range(tf.shape(indices)[0], dtype=indices.dtype), + tf.shape(mapping)[1], + ).reshape(tf.shape(indices)), + indices, + ], + axis=-1, + ) + virial = tf.scatter_nd( + indices, + model_ret[kk_derv_c], + tf.cast(tf.concat([vldims, derv_c_ext_dims], axis=0), tf.int64), + ) + new_ret[kk_derv_c] = virial.reshape( + tf.concat([tf.shape(virial)[:2], list(vdef.shape), [9]], axis=0) + ) + new_ret[kk_derv_c + "_redu"] = tnp.sum(new_ret[kk_derv_c], axis=1) + else: + new_ret[kk_derv_c] = None + new_ret[kk_derv_c + "_redu"] = None + if not do_atomic_virial: + # pop atomic virial, because it is not correctly calculated. 
+ new_ret.pop(kk_derv_c) + return new_ret diff --git a/deepmd/jax/model/hlo.py b/deepmd/jax/model/hlo.py index 2946f8bec7..4d59957456 100644 --- a/deepmd/jax/model/hlo.py +++ b/deepmd/jax/model/hlo.py @@ -46,6 +46,8 @@ def __init__( self, stablehlo, stablehlo_atomic_virial, + stablehlo_no_ghost, + stablehlo_atomic_virial_no_ghost, model_def_script, type_map, rcut, @@ -62,6 +64,10 @@ def __init__( self._call_lower_atomic_virial = jax_export.deserialize( stablehlo_atomic_virial ).call + self._call_lower_no_ghost = jax_export.deserialize(stablehlo_no_ghost).call + self._call_lower_atomic_virial_no_ghost = jax_export.deserialize( + stablehlo_atomic_virial_no_ghost + ).call self.stablehlo = stablehlo self.type_map = type_map self.rcut = rcut @@ -174,10 +180,16 @@ def call_lower( aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, ): - if do_atomic_virial: - call_lower = self._call_lower_atomic_virial + if extended_coord.shape[1] > nlist.shape[1]: + if do_atomic_virial: + call_lower = self._call_lower_atomic_virial + else: + call_lower = self._call_lower else: - call_lower = self._call_lower + if do_atomic_virial: + call_lower = self._call_lower_atomic_virial_no_ghost + else: + call_lower = self._call_lower_no_ghost return call_lower( extended_coord, extended_atype, diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 6ab99a81f0..1ed26f2d40 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -53,7 +53,9 @@ def deserialize_to_file(model_file: str, data: dict) -> None: nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost") - def exported_whether_do_atomic_virial(do_atomic_virial): + def exported_whether_do_atomic_virial( + do_atomic_virial: bool, has_ghost_atoms: bool + ): def call_lower_with_fixed_do_atomic_virial( coord, atype, nlist, mapping, fparam, aparam ): @@ -67,13 +69,18 @@ def call_lower_with_fixed_do_atomic_virial( do_atomic_virial=do_atomic_virial, ) + if 
has_ghost_atoms: + nghost_ = nghost + else: + nghost_ = 0 + return jax_export.export(jax.jit(call_lower_with_fixed_do_atomic_virial))( jax.ShapeDtypeStruct( - (nf, nloc + nghost, 3), jnp.float64 + (nf, nloc + nghost_, 3), jnp.float64 ), # extended_coord - jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype + jax.ShapeDtypeStruct((nf, nloc + nghost_), jnp.int32), # extended_atype jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist - jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping + jax.ShapeDtypeStruct((nf, nloc + nghost_), jnp.int64), # mapping jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64) if model.get_dim_fparam() else None, # fparam @@ -82,18 +89,34 @@ def call_lower_with_fixed_do_atomic_virial( else None, # aparam ) - exported = exported_whether_do_atomic_virial(do_atomic_virial=False) + exported = exported_whether_do_atomic_virial( + do_atomic_virial=False, has_ghost_atoms=True + ) exported_atomic_virial = exported_whether_do_atomic_virial( - do_atomic_virial=True + do_atomic_virial=True, has_ghost_atoms=True ) serialized: bytearray = exported.serialize() serialized_atomic_virial = exported_atomic_virial.serialize() + + exported_no_ghost = exported_whether_do_atomic_virial( + do_atomic_virial=False, has_ghost_atoms=False + ) + exported_atomic_virial_no_ghost = exported_whether_do_atomic_virial( + do_atomic_virial=True, has_ghost_atoms=False + ) + serialized_no_ghost: bytearray = exported_no_ghost.serialize() + serialized_atomic_virial_no_ghost = exported_atomic_virial_no_ghost.serialize() + data = data.copy() data.setdefault("@variables", {}) data["@variables"]["stablehlo"] = np.void(serialized) data["@variables"]["stablehlo_atomic_virial"] = np.void( serialized_atomic_virial ) + data["@variables"]["stablehlo_no_ghost"] = np.void(serialized_no_ghost) + data["@variables"]["stablehlo_atomic_virial_no_ghost"] = np.void( + serialized_atomic_virial_no_ghost + ) data["constants"] = { "type_map": 
model.get_type_map(), "rcut": model.get_rcut(), diff --git a/pyproject.toml b/pyproject.toml index 7d64d48e80..cf2c8c3b93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -406,8 +406,10 @@ convention = "numpy" banned-module-level-imports = [ "deepmd.tf", "deepmd.pt", + "deepmd.jax", "tensorflow", "torch", + "jax", ] [tool.ruff.lint.flake8-tidy-imports.banned-api] @@ -420,7 +422,9 @@ banned-module-level-imports = [ "deepmd/jax/**" = ["TID253"] "source/tests/tf/**" = ["TID253"] "source/tests/pt/**" = ["TID253"] +"source/tests/jax/**" = ["TID253"] "source/tests/universal/pt/**" = ["TID253"] +"source/jax2tf_tests/**" = ["TID253"] "source/ipi/tests/**" = ["TID253"] "source/lmp/tests/**" = ["TID253"] "**/*.ipynb" = ["T20"] # printing in a nb file is expected diff --git a/source/jax2tf_tests/__init__.py b/source/jax2tf_tests/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/jax2tf_tests/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/jax2tf_tests/test_nlist.py b/source/jax2tf_tests/test_nlist.py new file mode 100644 index 0000000000..5b13e4231c --- /dev/null +++ b/source/jax2tf_tests/test_nlist.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + + +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +from deepmd.jax.jax2tf.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.jax.jax2tf.region import ( + inter2phys, +) + +dtype = tnp.float64 + + +class TestNeighList(tf.test.TestCase): + def setUp(self): + self.nf = 3 + self.nloc = 3 + self.ns = 5 * 5 * 3 + self.nall = self.ns * self.nloc + self.cell = tnp.array([[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], dtype=dtype) + self.icoord = tnp.array([[0, 0, 0], [0, 0, 0], [0.5, 0.5, 0.1]], dtype=dtype) + self.atype = tnp.array([-1, 0, 1], dtype=tnp.int32) + [self.cell, self.icoord, self.atype] = [ + tnp.expand_dims(ii, 0) for ii in [self.cell, self.icoord, self.atype] + ] + 
self.coord = inter2phys(self.icoord, self.cell).reshape([-1, self.nloc * 3]) + self.cell = self.cell.reshape([-1, 9]) + [self.cell, self.coord, self.atype] = [ + tnp.tile(ii, [self.nf, 1]) for ii in [self.cell, self.coord, self.atype] + ] + self.rcut = 1.01 + self.prec = 1e-10 + self.nsel = [10, 10] + self.ref_nlist = tnp.array( + [ + [-1] * sum(self.nsel), + [1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], + [1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 2, 2, 2, 2, 2, 2, -1, -1, -1, -1], + ] + ) + + def test_build_notype(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + nlist = build_neighbor_list( + ecoord, + eatype, + self.nloc, + self.rcut, + sum(self.nsel), + distinguish_types=False, + ) + self.assertAllClose(nlist[0], nlist[1]) + nlist_mask = nlist[0] == -1 + nlist_loc = mapping[0][nlist[0]] + nlist_loc = tnp.where(nlist_mask, tnp.full_like(nlist_loc, -1), nlist_loc) + self.assertAllClose( + tnp.sort(nlist_loc, axis=-1), + tnp.sort(self.ref_nlist, axis=-1), + ) + + def test_build_type(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + nlist = build_neighbor_list( + ecoord, + eatype, + self.nloc, + self.rcut, + self.nsel, + distinguish_types=True, + ) + self.assertAllClose(nlist[0], nlist[1]) + nlist_mask = nlist[0] == -1 + nlist_loc = mapping[0][nlist[0]] + nlist_loc = tnp.where(nlist_mask, tnp.full_like(nlist_loc, -1), nlist_loc) + for ii in range(2): + self.assertAllClose( + tnp.sort(tnp.split(nlist_loc, self.nsel, axis=-1)[ii], axis=-1), + tnp.sort(tnp.split(self.ref_nlist, self.nsel, axis=-1)[ii], axis=-1), + ) + + def test_extend_coord(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + # expected ncopy x nloc + self.assertEqual(list(ecoord.shape), [self.nf, self.nall * 3]) + self.assertEqual(list(eatype.shape), [self.nf, self.nall]) + 
self.assertEqual(list(mapping.shape), [self.nf, self.nall]) + # check the nloc part is identical with original coord + self.assertAllClose( + ecoord[:, : self.nloc * 3], self.coord, rtol=self.prec, atol=self.prec + ) + # check the shift vectors are aligned with grid + shift_vec = ( + ecoord.reshape([-1, self.ns, self.nloc, 3]) + - self.coord.reshape([-1, self.nloc, 3])[:, None, :, :] + ) + shift_vec = shift_vec.reshape([-1, self.nall, 3]) + # hack!!! assumes identical cell across frames + shift_vec = tnp.matmul( + shift_vec, tf.linalg.inv(self.cell.reshape([self.nf, 3, 3])[0]) + ) + # nf x nall x 3 + shift_vec = tnp.round(shift_vec) + # check: identical shift vecs + self.assertAllClose(shift_vec[0], shift_vec[1], rtol=self.prec, atol=self.prec) + # check: shift idx aligned with grid + mm, _, cc = tf.unique_with_counts(shift_vec[0][:, 0]) + self.assertAllClose( + tnp.sort(mm), + tnp.array([-2, -1, 0, 1, 2], dtype=dtype), + rtol=self.prec, + atol=self.prec, + ) + self.assertAllClose( + cc, + tnp.array([self.ns * self.nloc // 5] * 5, dtype=tnp.int32), + rtol=self.prec, + atol=self.prec, + ) + mm, _, cc = tf.unique_with_counts(shift_vec[1][:, 1]) + self.assertAllClose( + tnp.sort(mm), + tnp.array([-2, -1, 0, 1, 2], dtype=dtype), + rtol=self.prec, + atol=self.prec, + ) + self.assertAllClose( + cc, + tnp.array([self.ns * self.nloc // 5] * 5, dtype=tnp.int32), + rtol=self.prec, + atol=self.prec, + ) + mm, _, cc = tf.unique_with_counts(shift_vec[1][:, 2]) + self.assertAllClose( + tnp.sort(mm), + tnp.array([-1, 0, 1], dtype=dtype), + rtol=self.prec, + atol=self.prec, + ) + self.assertAllClose( + cc, + tnp.array([self.ns * self.nloc // 3] * 3, dtype=tnp.int32), + rtol=self.prec, + atol=self.prec, + ) diff --git a/source/jax2tf_tests/test_region.py b/source/jax2tf_tests/test_region.py new file mode 100644 index 0000000000..2becf08c94 --- /dev/null +++ b/source/jax2tf_tests/test_region.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + + +import tensorflow as 
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from deepmd.jax.jax2tf.region import (
    inter2phys,
    to_face_distance,
)

GLOBAL_SEED = 20241109


class TestRegion(tf.test.TestCase):
    """Check ``inter2phys`` and ``to_face_distance`` on a batch of cells.

    A single 3 x 3 cell matrix is replicated over a ``BATCH_SHAPE`` grid, so
    every batch entry is expected to yield the same reference values.
    """

    # Leading (batch) dimensions over which the reference cell is tiled.
    BATCH_SHAPE = (4, 5)

    def setUp(self):
        """Build the tiled cell tensor and the comparison tolerance."""
        # Chain to the base class so its own per-test setup is not skipped
        # when this override runs.
        super().setUp()
        cell = tnp.array(
            [[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]],
        )
        cell = tnp.reshape(cell, [1, 1, -1, 3])
        # resulting shape: BATCH_SHAPE + (3, 3)
        self.cell = tnp.tile(cell, [*self.BATCH_SHAPE, 1, 1])
        self.prec = 1e-8

    def test_inter_to_phys(self):
        """``inter2phys(inter, cell)`` must equal ``inter @ cell`` per batch entry."""
        n0, n1 = self.BATCH_SHAPE
        rng = tf.random.Generator.from_seed(GLOBAL_SEED)
        # NOTE(review): no dtype is passed here, so the generator's default
        # dtype is used; confirm it is the precision intended for ``prec``.
        inter = rng.normal(shape=[n0, n1, 3, 3])
        phys = inter2phys(inter, self.cell)
        for ii in range(n0):
            for jj in range(n1):
                expected_phys = tnp.matmul(inter[ii, jj], self.cell[ii, jj])
                self.assertAllClose(
                    phys[ii, jj], expected_phys, rtol=self.prec, atol=self.prec
                )

    def test_to_face_dist(self):
        """``to_face_distance`` must equal cell volume over face area, per axis."""
        n0, n1 = self.BATCH_SHAPE
        # every batch entry holds the same cell, so one reference suffices
        cell0 = self.cell[0, 0]
        vol = tf.linalg.det(cell0)
        # area of surfaces xy, xz, yz
        sxy = tf.linalg.norm(tnp.cross(cell0[0], cell0[1]))
        sxz = tf.linalg.norm(tnp.cross(cell0[0], cell0[2]))
        syz = tf.linalg.norm(tnp.cross(cell0[1], cell0[2]))
        # vol / area gives distance
        dz = vol / sxy
        dy = vol / sxz
        dx = vol / syz
        expected = tnp.array([dx, dy, dz])
        dists = to_face_distance(self.cell)
        for ii in range(n0):
            for jj in range(n1):
                self.assertAllClose(
                    dists[ii, jj], expected, rtol=self.prec, atol=self.prec
                )
): with self.subTest(backend_name=backend_name): backend = Backend.get_backend(backend_name)() @@ -142,13 +140,16 @@ def test_deep_eval(self): nframes = self.atype.shape[0] prefix = "test_consistent_io_" + self.__class__.__name__.lower() rets = [] + rets_nopbc = [] for backend_name, suffix_idx in ( # unfortunately, jax2tf cannot work with tf v1 behaviors ("jax", 2) if DP_TEST_TF2_ONLY else ("tensorflow", 0), ("pytorch", 0), ("dpmodel", 0), - ("jax", 0), + ("jax", 0) if DP_TEST_TF2_ONLY else (None, None), ): + if backend_name is None: + continue backend = Backend.get_backend(backend_name)() if not backend.is_available(): continue @@ -182,6 +183,23 @@ def test_deep_eval(self): atomic=True, ) rets.append(ret) + ret = deep_eval.eval( + self.coords, + None, + self.atype, + fparam=fparam, + aparam=aparam, + ) + rets_nopbc.append(ret) + ret = deep_eval.eval( + self.coords, + None, + self.atype, + fparam=fparam, + aparam=aparam, + atomic=True, + ) + rets_nopbc.append(ret) for ret in rets[1:]: for vv1, vv2 in zip(rets[0], ret): if np.isnan(vv2).all(): @@ -189,6 +207,15 @@ def test_deep_eval(self): continue np.testing.assert_allclose(vv1, vv2, rtol=1e-12, atol=1e-12) + for idx, ret in enumerate(rets_nopbc[1:]): + for vv1, vv2 in zip(rets_nopbc[0], ret): + if np.isnan(vv2).all(): + # expect all nan if not supported + continue + np.testing.assert_allclose( + vv1, vv2, rtol=1e-12, atol=1e-12, err_msg=f"backend {idx+1}" + ) + class TestDeepPot(unittest.TestCase, IOTest): def setUp(self): diff --git a/source/tests/jax/__init__.py b/source/tests/jax/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/jax/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later