diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 29577ef79e..d29ce8862e 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -7,6 +7,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( @@ -14,6 +15,9 @@ PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils import ( EmbeddingNet, EnvMat, @@ -186,15 +190,15 @@ def __init__( self.reinit_exclude(exclude_types) in_dim = 1 # not considiering type embedding - self.embeddings = NetworkCollection( + embeddings = NetworkCollection( ntypes=self.ntypes, ndim=(1 if self.type_one_side else 2), network_type="embedding_network", ) for ii, embedding_idx in enumerate( - itertools.product(range(self.ntypes), repeat=self.embeddings.ndim) + itertools.product(range(self.ntypes), repeat=embeddings.ndim) ): - self.embeddings[embedding_idx] = EmbeddingNet( + embeddings[embedding_idx] = EmbeddingNet( in_dim, self.neuron, self.activation_function, @@ -202,8 +206,9 @@ def __init__( self.precision, seed=child_seed(seed, ii), ) + self.embeddings = embeddings self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection) - self.nnei = np.sum(self.sel) + self.nnei = np.sum(self.sel).item() self.davg = np.zeros( [self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision] ) @@ -211,6 +216,7 @@ def __init__( [self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision] ) self.orig_sel = self.sel + self.sel_cumsum = [0, *np.cumsum(self.sel).tolist()] def __setitem__(self, key, value): if key in ("avg", "data_avg", "davg"): @@ -321,8 +327,9 @@ def cal_g( ss, embedding_idx, ): + xp = array_api_compat.array_namespace(ss) nf_times_nloc, nnei = ss.shape[0:2] - ss = ss.reshape(nf_times_nloc, nnei, 1) + ss = xp.reshape(ss, (nf_times_nloc, nnei, 1)) # (nf x nloc) x nnei x ng gg = self.embeddings[embedding_idx].call(ss) return gg @@ -444,8 +451,8 @@ def serialize(self) -> dict: "env_mat": self.env_mat.serialize(), "embeddings": self.embeddings.serialize(), "@variables": { - "davg": self.davg, - "dstd": self.dstd, + "davg": to_numpy_array(self.davg), + "dstd": to_numpy_array(self.dstd), }, "type_map": self.type_map, } @@ -497,3 +504,89 @@ def update_sel( train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False ) return local_jdata_cpy, min_nbor_dist + + +class DescrptSeAArrayAPI(DescrptSeA): + def call( + self, + coord_ext, + atype_ext, + nlist, + mapping: Optional[np.ndarray] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping from extended to lcoal region. not used by this descriptor. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + this descriptor returns None + h2 + The rotationally equivariant pair-partical representation. + this descriptor returns None + sw + The smooth switch function. + """ + if not self.type_one_side: + raise NotImplementedError( + "type_one_side == False is not supported in DescrptSeAArrayAPI" + ) + del mapping + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) + input_dtype = coord_ext.dtype + # nf x nloc x nnei x 4 + rr, diff, ww = self.env_mat.call( + coord_ext, atype_ext, nlist, self.davg, self.dstd + ) + nf, nloc, nnei, _ = rr.shape + sec = xp.asarray(self.sel_cumsum) + + ng = self.neuron[-1] + gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype) + exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) + # merge nf and nloc axis, so for type_one_side == False, + # we don't require atype is the same in all frames + exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) + rr = xp.reshape(rr, (nf * nloc, nnei, 4)) + rr = xp.astype(rr, self.dstd.dtype) + + for embedding_idx in itertools.product( + range(self.ntypes), repeat=self.embeddings.ndim + ): + (tt,) = embedding_idx + mm = exclude_mask[:, sec[tt] : sec[tt + 1]] + tr = rr[:, sec[tt] : sec[tt + 1], :] + tr = tr * xp.astype(mm[:, :, None], tr.dtype) + ss = tr[..., 0:1] + gg = self.cal_g(ss, embedding_idx) + # gr_tmp = xp.einsum("lni,lnj->lij", gg, tr) + gr_tmp = xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1) + gr += gr_tmp + gr = xp.reshape(gr, (nf, nloc, ng, 4)) + # nf x nloc x ng x 4 + gr /= self.nnei + gr1 = gr[:, :, : self.axis_neuron, :] + # nf x nloc x ng x ng1 + # grrg = xp.einsum("flid,fljd->flij", gr, gr1) + grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4) + # nf x nloc x (ng x ng1) + grrg = xp.astype( + xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)), input_dtype + ) + return grrg, gr[..., 1:], None, None, ww diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index 4806fa4cd8..c56f1bc061 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -163,20 +163,20 @@ def nlist_distinguish_types( xp = array_api_compat.array_namespace(nlist, atype) nf, nloc, _ = nlist.shape ret_nlist = [] - tmp_atype = xp.tile(atype[:, None], [1, nloc, 1]) + tmp_atype = xp.tile(atype[:, None, :], (1, nloc, 1)) mask = nlist == -1 - tnlist_0 = nlist.copy() - tnlist_0[mask] = 0 - tnlist = xp_take_along_axis(tmp_atype, tnlist_0, axis=2).squeeze() - tnlist = xp.where(mask, -1, tnlist) - snsel = tnlist.shape[2] + tnlist_0 = xp.where(mask, xp.zeros_like(nlist), nlist) + tnlist = xp_take_along_axis(tmp_atype, tnlist_0, axis=2) + tnlist = xp.where(mask, xp.full_like(tnlist, -1), tnlist) for ii, ss in enumerate(sel): - pick_mask = (tnlist == ii).astype(xp.int32) - sorted_indices = xp.argsort(-pick_mask, kind="stable", axis=-1) + pick_mask = xp.astype(tnlist == ii, xp.int32) + sorted_indices = xp.argsort(-pick_mask, stable=True, axis=-1) pick_mask_sorted = -xp.sort(-pick_mask, axis=-1) inlist = xp_take_along_axis(nlist, sorted_indices, axis=2) - inlist = xp.where(~pick_mask_sorted.astype(bool), -1, inlist) - ret_nlist.append(xp.split(inlist, [ss, snsel - ss], axis=-1)[0]) + inlist = xp.where( + ~xp.astype(pick_mask_sorted, xp.bool), xp.full_like(inlist, -1), inlist + ) + ret_nlist.append(inlist[..., :ss]) ret = xp.concat(ret_nlist, axis=-1) return ret diff --git a/deepmd/jax/descriptor/se_e2_a.py b/deepmd/jax/descriptor/se_e2_a.py new file mode 100644 index 0000000000..a60a4e9af1 --- /dev/null +++ b/deepmd/jax/descriptor/se_e2_a.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP +from deepmd.jax.common import ( + flax_module, + to_jax_array, +) +from deepmd.jax.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.jax.utils.network import ( + NetworkCollection, +) + + +@flax_module +class DescrptSeA(DescrptSeADP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"dstd", "davg"}: + value = to_jax_array(value) + elif name in {"embeddings"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/descriptor/se_e2_a.py b/source/tests/array_api_strict/descriptor/se_e2_a.py new file mode 100644 index 0000000000..654b9f8925 --- /dev/null +++ b/source/tests/array_api_strict/descriptor/se_e2_a.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP + +from ..common import ( + to_array_api_strict_array, +) +from ..utils.exclude_mask import ( + PairExcludeMask, +) +from ..utils.network import ( + NetworkCollection, +) + + +class DescrptSeA(DescrptSeADP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"dstd", "davg"}: + value = to_array_api_strict_array(value) + elif name in {"embeddings"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index 2563ee1d6d..286703e21d 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -12,6 +12,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -33,6 +35,17 @@ descrpt_se_a_args, ) +if INSTALLED_JAX: + from deepmd.jax.descriptor.se_e2_a import DescrptSeA as DescrptSeAJAX +else: + DescrptSeAJAX = None +if INSTALLED_ARRAY_API_STRICT: + from ...array_api_strict.descriptor.se_e2_a import ( + DescrptSeA as DescrptSeAArrayAPIStrict, + ) +else: + DescrptSeAArrayAPIStrict = None + @parameterized( (True, False), # resnet_dt @@ -98,9 +111,33 @@ def skip_tf(self) -> bool: ) = self.param return env_protection != 0.0 + @property + def skip_jax(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + env_protection, + ) = self.param + return not type_one_side or not INSTALLED_JAX + + @property + def skip_array_api_strict(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + env_protection, + ) = self.param + return not type_one_side or not INSTALLED_ARRAY_API_STRICT + tf_class = DescrptSeATF dp_class = DescrptSeADP pt_class = DescrptSeAPT + jax_class = DescrptSeAJAX + array_api_strict_class = DescrptSeAArrayAPIStrict args = descrpt_se_a_args() def setUp(self): @@ -177,6 +214,24 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_descriptor( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return self.eval_array_api_strict_descriptor( + array_api_strict_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0],)