From cdf6d69537196d77d923cbdb7c721bd11be17656 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 25 Feb 2024 23:18:22 -0500 Subject: [PATCH 1/3] feat(pt/dpmodel): support type_one_side in se_e2_a Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/se_e2_a.py | 63 +++++++++++++------ deepmd/dpmodel/utils/exclude_mask.py | 3 + deepmd/pt/model/descriptor/se_a.py | 36 ++++++++--- .../consistent/descriptor/test_se_e2_a.py | 15 ++++- 4 files changed, 87 insertions(+), 30 deletions(-) diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index be2ed12394..0c1ad1724e 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import itertools + import numpy as np from deepmd.utils.path import ( @@ -144,8 +146,6 @@ def __init__( seed: Optional[int] = None, ) -> None: ## seed, uniform_seed, multi_task, not included. - if not type_one_side: - raise NotImplementedError("type_one_side == False not implemented") if spin is not None: raise NotImplementedError("spin is not implemented") @@ -171,10 +171,10 @@ def __init__( ndim=(1 if self.type_one_side else 2), network_type="embedding_network", ) - if not self.type_one_side: - raise NotImplementedError("type_one_side == False not implemented") - for ii in range(self.ntypes): - self.embeddings[(ii,)] = EmbeddingNet( + for embedding_idx in itertools.product( + range(self.ntypes), repeat=self.embeddings.ndim + ): + self.embeddings[embedding_idx] = EmbeddingNet( in_dim, self.neuron, self.activation_function, @@ -241,12 +241,12 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) def cal_g( self, ss, - ll, + embedding_idx, ): - nf, nloc, nnei = ss.shape[0:3] - ss = ss.reshape(nf, nloc, nnei, 1) - # nf x nloc x nnei x ng - gg = self.embeddings[(ll,)].call(ss) + nf_times_nloc, nnei = ss.shape[0:2] + ss = ss.reshape(nf_times_nloc, nnei, 1) + # (nf x nloc) x nnei x ng + gg = self.embeddings[embedding_idx].call(ss) return gg def call( @@ -292,16 +292,30 @@ def call( sec = np.append([0], np.cumsum(self.sel)) ng = self.neuron[-1] - gr = np.zeros([nf, nloc, ng, 4]) + gr = np.zeros([nf * nloc, ng, 4]) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) - for tt in range(self.ntypes): - mm = exclude_mask[:, :, sec[tt] : sec[tt + 1]] - tr = rr[:, :, sec[tt] : sec[tt + 1], :] - tr = tr * mm[:, :, :, None] + # merge nf and nloc axis, so for type_one_side == False, + # we don't require atype is the same in all frames + exclude_mask = exclude_mask.reshape(nf * nloc, nnei) + rr = rr.reshape(nf * nloc, nnei, 4) + + for embedding_idx in itertools.product( + range(self.ntypes), repeat=self.embeddings.ndim + ): + if self.type_one_side: + (tt,) = embedding_idx + ti_mask = np.s_[:] + else: + ti, tt = embedding_idx + ti_mask = atype_ext[:, :nloc].ravel() == ti + mm = exclude_mask[ti_mask, sec[tt] : sec[tt + 1]] + tr = rr[ti_mask, sec[tt] : sec[tt + 1], :] + tr = tr * mm[:, :, None] ss = tr[..., 0:1] - gg = self.cal_g(ss, tt) - # nf x nloc x ng x 4 - gr += np.einsum("flni,flnj->flij", gg, tr) + gg = self.cal_g(ss, embedding_idx) + gr_tmp = np.einsum("lni,lnj->lij", gg, tr) + gr[ti_mask] += gr_tmp + gr = gr.reshape(nf, nloc, ng, 4) # nf x nloc x ng x 4 gr /= self.nnei gr1 = gr[:, :, : self.axis_neuron, :] @@ -313,6 +327,17 @@ def call( def serialize(self) -> dict: """Serialize the descriptor to dict.""" + if not self.type_one_side and self.exclude_types: + for embedding_idx in itertools.product(range(self.ntypes), repeat=2): + # not actually used; to match serilization data from TF to pass the test + if embedding_idx in self.emask: + for ilayer in range(len(self.neuron)): + layer = self.embeddings[embedding_idx][ilayer] + layer.w.fill(0.0) + layer.b.fill(0.0) + if layer.idt is not None: + layer.idt.fill(0.0) + return { "@class": "Descriptor", "type": "se_e2_a", diff --git a/deepmd/dpmodel/utils/exclude_mask.py b/deepmd/dpmodel/utils/exclude_mask.py index 83e3c7a363..360f190e13 100644 --- a/deepmd/dpmodel/utils/exclude_mask.py +++ b/deepmd/dpmodel/utils/exclude_mask.py @@ -115,3 +115,6 @@ def build_type_exclude_mask( type_ij = type_ij.reshape(nf, nloc * nnei) mask = self.type_mask[type_ij].reshape(nf, nloc, nnei) return mask + + def __contains__(self, item): + return item in self.exclude_types diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 0550488ecf..4ff5c2cf89 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import itertools from typing import ( ClassVar, Dict, @@ -67,6 +68,7 @@ def __init__( resnet_dt: bool = False, exclude_types: List[Tuple[int, int]] = [], old_impl: bool = False, + type_one_side: bool = True, **kwargs, ): super().__init__() @@ -82,6 +84,7 @@ def __init__( resnet_dt=resnet_dt, exclude_types=exclude_types, old_impl=old_impl, + type_one_side=type_one_side, **kwargs, ) @@ -214,7 +217,7 @@ def serialize(self) -> dict: }, ## to be updated when the options are supported. "trainable": True, - "type_one_side": True, + "type_one_side": obj.type_one_side, "spin": None, } @@ -255,6 +258,7 @@ def __init__( resnet_dt: bool = False, exclude_types: List[Tuple[int, int]] = [], old_impl: bool = False, + type_one_side: bool = True, **kwargs, ): """Construct an embedding net of type `se_a`. @@ -281,6 +285,7 @@ def __init__( self.exclude_types = exclude_types self.ntypes = len(sel) self.emask = PairExcludeMask(len(sel), exclude_types=exclude_types) + self.type_one_side = type_one_side self.sel = sel self.sec = torch.tensor( @@ -299,6 +304,10 @@ def __init__( self.filter_layers = None if self.old_impl: + if not self.type_one_side: + raise ValueError( + "The old implementation does not support type_one_side=False." + ) filter_layers = [] # TODO: remove start_index = 0 @@ -308,12 +317,12 @@ def __init__( start_index += sel[type_i] self.filter_layers_old = torch.nn.ModuleList(filter_layers) else: + ndim = 1 if self.type_one_side else 2 filter_layers = NetworkCollection( - ndim=1, ntypes=len(sel), network_type="embedding_network" + ndim=ndim, ntypes=len(sel), network_type="embedding_network" ) - # TODO: ndim=2 if type_one_side=False - for ii in range(self.ntypes): - filter_layers[(ii,)] = EmbeddingNet( + for embedding_idx in itertools.product(range(self.ntypes), repeat=ndim): + filter_layers[embedding_idx] = EmbeddingNet( 1, self.filter_neuron, activation_function=self.activation_function, @@ -471,18 +480,27 @@ def forward( ) # nfnl x nnei exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1) - for ii, ll in enumerate(self.filter_layers.networks): + for embedding_idx, ll in enumerate(self.filter_layers.networks): + if self.type_one_side: + ii = embedding_idx + ti_mask = slice(None) + else: + # ti: center atom type, ii: neighbor type... + ii = embedding_idx // self.ntypes + ti = embedding_idx % self.ntypes + ti_mask = atype.ravel().eq(ti) + print(ii, ti, ti_mask.shape, exclude_mask.shape, self.sec) # nfnl x nt - mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]] + mm = exclude_mask[ti_mask, self.sec[ii] : self.sec[ii + 1]] # nfnl x nt x 4 - rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :] + rr = dmatrix[ti_mask, self.sec[ii] : self.sec[ii + 1], :] rr = rr * mm[:, :, None] ss = rr[:, :, :1] # nfnl x nt x ng gg = ll.forward(ss) # nfnl x 4 x ng gr = torch.matmul(rr.permute(0, 2, 1), gg) - xyz_scatter += gr + xyz_scatter[ti_mask] += gr xyz_scatter /= self.nnei xyz_scatter_1 = xyz_scatter.permute(0, 2, 1) diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index fe20278e6f..0243a77044 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -71,7 +71,7 @@ def skip_pt(self) -> bool: excluded_types, precision, ) = self.param - return not type_one_side or CommonTest.skip_pt + return CommonTest.skip_pt @property def skip_dp(self) -> bool: @@ -81,7 +81,7 @@ def skip_dp(self) -> bool: excluded_types, precision, ) = self.param - return not type_one_side or CommonTest.skip_dp + return CommonTest.skip_dp tf_class = DescrptSeATF dp_class = DescrptSeADP @@ -121,6 +121,17 @@ def setUp(self): dtype=GLOBAL_NP_FLOAT_PRECISION, ) self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + # TF se_e2_a type_one_side=False requires atype sorted + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + ) = self.param + if not type_one_side: + idx = np.argsort(self.atype) + self.atype = self.atype[idx] + self.coords = self.coords.reshape(-1, 3)[idx].ravel() def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]: return self.build_tf_descriptor( From 8c4f54a1f89bddb99f905ffe4f55486f943c2d3b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 25 Feb 2024 23:27:30 -0500 Subject: [PATCH 2/3] make jit happy Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_a.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 4ff5c2cf89..fe38d081fc 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -483,13 +483,13 @@ def forward( for embedding_idx, ll in enumerate(self.filter_layers.networks): if self.type_one_side: ii = embedding_idx - ti_mask = slice(None) + # torch.jit is not happy with slice(None) + ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device) else: # ti: center atom type, ii: neighbor type... ii = embedding_idx // self.ntypes ti = embedding_idx % self.ntypes ti_mask = atype.ravel().eq(ti) - print(ii, ti, ti_mask.shape, exclude_mask.shape, self.sec) # nfnl x nt mm = exclude_mask[ti_mask, self.sec[ii] : self.sec[ii + 1]] # nfnl x nt x 4 From b3160c878aaed8187cc60589bdcd70f7be872dda Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 26 Feb 2024 16:37:05 -0500 Subject: [PATCH 3/3] move to a `Network.clear()` method Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/se_e2_a.py | 7 +------ deepmd/dpmodel/utils/network.py | 9 +++++++++ deepmd/tf/descriptor/se.py | 11 ++--------- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 0c1ad1724e..97ab719c62 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -331,12 +331,7 @@ def serialize(self) -> dict: for embedding_idx in itertools.product(range(self.ntypes), repeat=2): # not actually used; to match serilization data from TF to pass the test if embedding_idx in self.emask: - for ilayer in range(len(self.neuron)): - layer = self.embeddings[embedding_idx][ilayer] - layer.w.fill(0.0) - layer.b.fill(0.0) - if layer.idt is not None: - layer.idt.fill(0.0) + self.embeddings[embedding_idx].clear() return { "@class": "Descriptor", diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index c0a62c9a3e..2133bc4889 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -396,6 +396,15 @@ def call(self, x): x = layer(x) return x + def clear(self): + """Clear the network parameters to zero.""" + for layer in self.layers: + layer.w.fill(0.0) + if layer.b is not None: + layer.b.fill(0.0) + if layer.idt is not None: + layer.idt.fill(0.0) + return NN diff --git a/deepmd/tf/descriptor/se.py b/deepmd/tf/descriptor/se.py index 98d98cd467..857c8b28df 100644 --- a/deepmd/tf/descriptor/se.py +++ b/deepmd/tf/descriptor/se.py @@ -231,15 +231,8 @@ def serialize_network( resnet_dt=resnet_dt, precision=self.precision.name, ) - for layer in range(len(neuron)): - embeddings[(type_i, type_j)][layer]["w"][:] = 0.0 - embeddings[(type_i, type_j)][layer]["b"][:] = 0.0 - if embeddings[(type_i, type_j)][layer]["idt"] is not None: - embeddings[(type_i, type_j)][layer]["idt"][:] = 0.0 - embeddings[(type_j, type_i)][layer]["w"][:] = 0.0 - embeddings[(type_j, type_i)][layer]["b"][:] = 0.0 - if embeddings[(type_j, type_i)][layer]["idt"] is not None: - embeddings[(type_j, type_i)][layer]["idt"][:] = 0.0 + embeddings[(type_i, type_j)].clear() + embeddings[(type_j, type_i)].clear() if suffix != "": embedding_net_pattern = (