From d48f84a30dd493fb9cd292ce4da9da84387866dc Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 6 May 2024 15:50:59 +0800 Subject: [PATCH] feat: Support stripped type embedding in DPA1 of PT/DP (#3712) This PR supports stripped type embedding in DPA1 of PT/DP: - Remove `stripped_type_embedding` params in all classes and use `tebd_input_mode` == "strip" instead. - Add stripped type embedding inplementation for DPA1 of PT/DP. - Add serialize and deserialize for stripped type embedding. Note: - Old TF inplementation has not consistent behaivior when `type_one_side`==True and `tebd_input_mode` == "strip", it always uses two_side type stripped embeddings input, which is also inconsistent with `DescrptSeAEbdV2` in TF (but the training still works and only raise `NotImplementedError` when doing serialization now) may need support from @nahso . - Old TF inplementation `init_variables` will not init `idt` weights from graph for `two_side_embeeding_net_variables` (fixed), I'm surprised that no ut failed before (maybe all tests use `resnet_dt` == False). - The TF implementation of `DescrptSeAtten` does not support serialization when `tebd_input_mode` == "strip". This limitation arises because the shape of `type_embedding` cannot be determined after init, as it is decided at runtime. While the consistent version `DescrptDPA1Compat` is compatible with this configuration. ## Summary by CodeRabbit - **New Features** - Enhanced model flexibility with new type embedding input modes: `concat` and `strip`. - **Bug Fixes** - Improved model compression logic alignment with new type embedding modes for more efficient operations. - **Documentation** - Updated documentation to explain the impact of new type embedding input modes on model descriptors. - **Tests** - Adjusted test cases to reflect changes in type embedding input modes for robust testing. --------- Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- deepmd/dpmodel/descriptor/dpa1.py | 92 +++++-- deepmd/pt/model/descriptor/dpa1.py | 42 ++-- deepmd/pt/model/descriptor/se_atten.py | 52 +++- deepmd/tf/descriptor/se_a.py | 9 +- deepmd/tf/descriptor/se_a_ebd_v2.py | 4 +- deepmd/tf/descriptor/se_a_mask.py | 4 +- deepmd/tf/descriptor/se_atten.py | 224 ++++++++++++++++-- deepmd/tf/descriptor/se_atten_v2.py | 2 +- deepmd/tf/nvnmd/data/data.py | 2 +- deepmd/tf/utils/graph.py | 74 +++--- deepmd/utils/argcheck.py | 25 +- doc/model/train-se-atten.md | 2 +- .../water/se_atten_dpa1_compat/input.json | 2 +- .../tests/consistent/descriptor/test_dpa1.py | 3 +- source/tests/pt/model/test_dpa1.py | 24 +- source/tests/tf/test_data_large_batch.py | 4 +- source/tests/tf/test_descrpt_se_atten.py | 4 +- source/tests/tf/test_finetune_se_atten.py | 6 +- .../tests/tf/test_init_frz_model_se_atten.py | 6 +- .../tf/test_model_compression_se_atten.py | 4 +- source/tests/tf/test_model_se_atten.py | 12 +- 21 files changed, 462 insertions(+), 135 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 8833348b52..39d773e3c6 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -135,8 +135,9 @@ class DescrptDPA1(NativeOP, BaseDescriptor): tebd_dim: int Dimension of the type embedding tebd_input_mode: str - The way to mix the type embeddings. Supported options are `concat`. - (TODO need to support stripped_type_embedding option) + The input mode of the type embedding. Supported modes are ["concat", "strip"]. + - "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. resnet_dt: bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) @@ -182,6 +183,12 @@ class DescrptDPA1(NativeOP, BaseDescriptor): Whether to use smooth process in attention weights calculation. concat_output_tebd: bool Whether to concat type embedding at the output of the descriptor. + stripped_type_embedding: bool, Optional + (Deprecated, kept only for compatibility.) + Whether to strip the type embedding into a separate embedding network. + Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'. + Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'. + The default value is `None`, which means the `tebd_input_mode` setting will be used instead. spin (Only support None to keep consistent with other backend references.) (Not used in this version. Not-none option is not implemented.) @@ -189,9 +196,6 @@ class DescrptDPA1(NativeOP, BaseDescriptor): Limitations ----------- - The currently implementation does not support the following features - 1. tebd_input_mode != 'concat' - The currently implementation will not support the following deprecated features 1. spin is not None 2. attn_mask == True @@ -233,19 +237,21 @@ def __init__( smooth_type_embedding: bool = True, concat_output_tebd: bool = True, spin: Optional[Any] = None, + stripped_type_embedding: Optional[bool] = None, # consistent with argcheck, not used though seed: Optional[int] = None, ) -> None: ## seed, uniform_seed, multi_task, not included. + # Ensure compatibility with the deprecated stripped_type_embedding option. + if stripped_type_embedding is not None: + # Use the user-set stripped_type_embedding parameter first + tebd_input_mode = "strip" if stripped_type_embedding else "concat" if spin is not None: raise NotImplementedError("old implementation of spin is not supported.") if attn_mask: raise NotImplementedError( "old implementation of attn_mask is not supported." ) - # TODO - if tebd_input_mode != "concat": - raise NotImplementedError("tebd_input_mode != 'concat' not implemented") # to keep consistent with default value in this backends if ln_eps is None: ln_eps = 1e-5 @@ -290,25 +296,38 @@ def __init__( activation_function="Linear", precision=precision, ) + self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2 if self.tebd_input_mode in ["concat"]: - if not self.type_one_side: - in_dim = 1 + self.tebd_dim * 2 - else: - in_dim = 1 + self.tebd_dim + self.embd_input_dim = 1 + self.tebd_dim_input else: - in_dim = 1 + self.embd_input_dim = 1 self.embeddings = NetworkCollection( ndim=0, ntypes=self.ntypes, network_type="embedding_network", ) self.embeddings[0] = EmbeddingNet( - in_dim, + self.embd_input_dim, self.neuron, self.activation_function, self.resnet_dt, self.precision, ) + if self.tebd_input_mode in ["strip"]: + self.embeddings_strip = NetworkCollection( + ndim=0, + ntypes=self.ntypes, + network_type="embedding_network", + ) + self.embeddings_strip[0] = EmbeddingNet( + self.tebd_dim_input, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + ) + else: + self.embeddings_strip = None self.dpa1_attention = NeighborGatedAttention( self.attn_layer, self.nnei, @@ -410,6 +429,18 @@ def cal_g( gg = self.embeddings[embedding_idx].call(ss) return gg + def cal_g_strip( + self, + ss, + embedding_idx, + ): + assert self.embeddings_strip is not None + nfnl, nnei = ss.shape[0:2] + ss = ss.reshape(nfnl, nnei, -1) + # nfnl x nnei x ng + gg = self.embeddings_strip[embedding_idx].call(ss) + return gg + def reinit_exclude( self, exclude_types: List[Tuple[int, int]] = [], @@ -500,11 +531,28 @@ def call( else: # nfnl x nnei x (1 + tebd_dim) ss = np.concatenate([ss, atype_embd_nlist], axis=-1) + # calculate gg + # nfnl x nnei x ng + gg = self.cal_g(ss, 0) + elif self.tebd_input_mode in ["strip"]: + # nfnl x nnei x ng + gg_s = self.cal_g(ss, 0) + assert self.embeddings_strip is not None + if not self.type_one_side: + # nfnl x nnei x (tebd_dim * 2) + tt = np.concatenate([atype_embd_nlist, atype_embd_nnei], axis=-1) + else: + # nfnl x nnei x tebd_dim + tt = atype_embd_nlist + # nfnl x nnei x ng + gg_t = self.cal_g_strip(tt, 0) + if self.smooth: + gg_t = gg_t * sw.reshape(-1, self.nnei, 1) + # nfnl x nnei x ng + gg = gg_s * gg_t + gg_s else: raise NotImplementedError - # calculate gg - gg = self.cal_g(ss, 0) input_r = dmatrix.reshape(-1, nnei, 4)[:, :, 1:4] / np.maximum( np.linalg.norm( dmatrix.reshape(-1, nnei, 4)[:, :, 1:4], axis=-1, keepdims=True @@ -532,7 +580,7 @@ def call( def serialize(self) -> dict: """Serialize the descriptor to dict.""" - return { + data = { "@class": "Descriptor", "type": "dpa1", "@version": 1, @@ -575,6 +623,9 @@ def serialize(self) -> dict: "trainable": True, "spin": None, } + if self.tebd_input_mode in ["strip"]: + data.update({"embeddings_strip": self.embeddings_strip.serialize()}) + return data @classmethod def deserialize(cls, data: dict) -> "DescrptDPA1": @@ -588,11 +639,18 @@ def deserialize(cls, data: dict) -> "DescrptDPA1": type_embedding = data.pop("type_embedding") attention_layers = data.pop("attention_layers") env_mat = data.pop("env_mat") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + embeddings_strip = data.pop("embeddings_strip") + else: + embeddings_strip = None obj = cls(**data) obj["davg"] = variables["davg"] obj["dstd"] = variables["dstd"] obj.embeddings = NetworkCollection.deserialize(embeddings) + if tebd_input_mode in ["strip"]: + obj.embeddings_strip = NetworkCollection.deserialize(embeddings_strip) obj.type_embedding = TypeEmbedNet.deserialize(type_embedding) obj.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers) return obj diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 71c8a13f46..0b28e388d5 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -118,8 +118,9 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): tebd_dim: int Dimension of the type embedding tebd_input_mode: str - The way to mix the type embeddings. Supported options are `concat`. - (TODO need to support stripped_type_embedding option) + The input mode of the type embedding. Supported modes are ["concat", "strip"]. + - "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. resnet_dt: bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) @@ -165,6 +166,12 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): Whether to use smooth process in attention weights calculation. concat_output_tebd: bool Whether to concat type embedding at the output of the descriptor. + stripped_type_embedding: bool, Optional + (Deprecated, kept only for compatibility.) + Whether to strip the type embedding into a separate embedding network. + Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'. + Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'. + The default value is `None`, which means the `tebd_input_mode` setting will be used instead. spin (Only support None to keep consistent with other backend references.) (Not used in this version. Not-none option is not implemented.) @@ -172,9 +179,6 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): Limitations ----------- - The currently implementation does not support the following features - 1. tebd_input_mode != 'concat' - The currently implementation will not support the following deprecated features 1. spin is not None 2. attn_mask == True @@ -196,8 +200,7 @@ def __init__( axis_neuron: int = 16, tebd_dim: int = 8, tebd_input_mode: str = "concat", - # set_davg_zero: bool = False, - set_davg_zero: bool = True, # TODO + set_davg_zero: bool = True, attn: int = 128, attn_layer: int = 2, attn_dotr: bool = True, @@ -216,25 +219,24 @@ def __init__( ln_eps: Optional[float] = 1e-5, smooth_type_embedding: bool = True, type_one_side: bool = False, + stripped_type_embedding: Optional[bool] = None, # not implemented - stripped_type_embedding: bool = False, spin=None, type: Optional[str] = None, seed: Optional[int] = None, old_impl: bool = False, ): super().__init__() - if stripped_type_embedding: - raise NotImplementedError("stripped_type_embedding is not supported.") + # Ensure compatibility with the deprecated stripped_type_embedding option. + if stripped_type_embedding is not None: + # Use the user-set stripped_type_embedding parameter first + tebd_input_mode = "strip" if stripped_type_embedding else "concat" if spin is not None: raise NotImplementedError("old implementation of spin is not supported.") if attn_mask: raise NotImplementedError( "old implementation of attn_mask is not supported." ) - # TODO - if tebd_input_mode != "concat": - raise NotImplementedError("tebd_input_mode != 'concat' not implemented") # to keep consistent with default value in this backends if ln_eps is None: ln_eps = 1e-5 @@ -377,7 +379,7 @@ def set_stat_mean_and_stddev( def serialize(self) -> dict: obj = self.se_atten - return { + data = { "@class": "Descriptor", "type": "dpa1", "@version": 1, @@ -420,6 +422,9 @@ def serialize(self) -> dict: "trainable": True, "spin": None, } + if obj.tebd_input_mode in ["strip"]: + data.update({"embeddings_strip": obj.filter_layers_strip.serialize()}) + return data @classmethod def deserialize(cls, data: dict) -> "DescrptDPA1": @@ -432,6 +437,11 @@ def deserialize(cls, data: dict) -> "DescrptDPA1": type_embedding = data.pop("type_embedding") attention_layers = data.pop("attention_layers") env_mat = data.pop("env_mat") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + embeddings_strip = data.pop("embeddings_strip") + else: + embeddings_strip = None obj = cls(**data) def t_cvt(xx): @@ -443,6 +453,10 @@ def t_cvt(xx): obj.se_atten["davg"] = t_cvt(variables["davg"]) obj.se_atten["dstd"] = t_cvt(variables["dstd"]) obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings) + if tebd_input_mode in ["strip"]: + obj.se_atten.filter_layers_strip = NetworkCollection.deserialize( + embeddings_strip + ) obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize( attention_layers ) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 66da86ce29..9bf4788bf2 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -106,8 +106,9 @@ def __init__( tebd_dim : int Dimension of the type embedding tebd_input_mode : str - The way to mix the type embeddings. Supported options are `concat`. - (TODO need to support stripped_type_embedding option) + The input mode of the type embedding. Supported modes are ["concat", "strip"]. + - "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. resnet_dt : bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) @@ -191,6 +192,9 @@ def __init__( # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) if self.old_impl: + assert self.tebd_input_mode in [ + "concat" + ], "Old implementation does not support tebd_input_mode != 'concat'." self.dpa1_attention = NeighborWiseAttention( self.attn_layer, self.nnei, @@ -230,16 +234,15 @@ def __init__( ) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) + self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2 if self.tebd_input_mode in ["concat"]: - if not self.type_one_side: - self.embd_input_dim = 1 + self.tebd_dim * 2 - else: - self.embd_input_dim = 1 + self.tebd_dim + self.embd_input_dim = 1 + self.tebd_dim_input else: self.embd_input_dim = 1 self.filter_layers_old = None self.filter_layers = None + self.filter_layers_strip = None if self.old_impl: filter_layers = [] one = TypeFilter( @@ -265,6 +268,18 @@ def __init__( resnet_dt=self.resnet_dt, ) self.filter_layers = filter_layers + if self.tebd_input_mode in ["strip"]: + filter_layers_strip = NetworkCollection( + ndim=0, ntypes=self.ntypes, network_type="embedding_network" + ) + filter_layers_strip[0] = EmbeddingNet( + self.tebd_dim_input, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + ) + self.filter_layers_strip = filter_layers_strip self.stats = None def get_rcut(self) -> float: @@ -498,19 +513,36 @@ def forward( rr = dmatrix rr = rr * exclude_mask[:, :, None] ss = rr[:, :, :1] + nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim) + atype_tebd = atype_tebd_nnei.reshape(nfnl, nnei, self.tebd_dim) if self.tebd_input_mode in ["concat"]: - nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim) - atype_tebd = atype_tebd_nnei.reshape(nfnl, nnei, self.tebd_dim) if not self.type_one_side: # nfnl x nnei x (1 + tebd_dim * 2) ss = torch.concat([ss, nlist_tebd, atype_tebd], dim=2) else: # nfnl x nnei x (1 + tebd_dim) ss = torch.concat([ss, nlist_tebd], dim=2) + # nfnl x nnei x ng + gg = self.filter_layers.networks[0](ss) + elif self.tebd_input_mode in ["strip"]: + # nfnl x nnei x ng + gg_s = self.filter_layers.networks[0](ss) + assert self.filter_layers_strip is not None + if not self.type_one_side: + # nfnl x nnei x (tebd_dim * 2) + tt = torch.concat([nlist_tebd, atype_tebd], dim=2) + else: + # nfnl x nnei x tebd_dim + tt = nlist_tebd + # nfnl x nnei x ng + gg_t = self.filter_layers_strip.networks[0](tt) + if self.smooth: + gg_t = gg_t * sw.reshape(-1, self.nnei, 1) + # nfnl x nnei x ng + gg = gg_s * gg_t + gg_s else: raise NotImplementedError - # nfnl x nnei x ng - gg = self.filter_layers._networks[0](ss) + input_r = torch.nn.functional.normalize( dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 ) diff --git a/deepmd/tf/descriptor/se_a.py b/deepmd/tf/descriptor/se_a.py index a3c96d3b9f..6684e21522 100644 --- a/deepmd/tf/descriptor/se_a.py +++ b/deepmd/tf/descriptor/se_a.py @@ -183,7 +183,7 @@ def __init__( uniform_seed: bool = False, multi_task: bool = False, spin: Optional[Spin] = None, - stripped_type_embedding: bool = False, + tebd_input_mode: str = "concat", env_protection: float = 0.0, # not implement!! **kwargs, ) -> None: @@ -194,6 +194,8 @@ def __init__( ) if env_protection != 0.0: raise NotImplementedError("env_protection != 0.0 is not supported.") + # to be compat with old option of `stripped_type_embedding` + stripped_type_embedding = tebd_input_mode == "strip" self.sel_a = sel self.rcut_r = rcut self.rcut_r_smth = rcut_smth @@ -1052,7 +1054,7 @@ def _filter_lower( ) if self.compress: raise RuntimeError( - "compression of type embedded descriptor is not supported when stripped_type_embedding == False" + "compression of type embedded descriptor is not supported when tebd_input_mode is not set to 'strip'" ) # natom x 4 x outputs_size if nvnmd_cfg.enable: @@ -1357,7 +1359,6 @@ def init_variables( graph_def, suffix, get_extra_embedding_net_suffix(self.type_one_side), - self.layer_size, ) ) @@ -1422,7 +1423,7 @@ def serialize(self, suffix: str = "") -> dict: ) if self.stripped_type_embedding: raise NotImplementedError( - "stripped_type_embedding is unsupported by the native model" + "Serialization is unsupported when tebd_input_mode is set to 'strip'" ) if (self.original_sel != self.sel_a).any(): raise NotImplementedError( diff --git a/deepmd/tf/descriptor/se_a_ebd_v2.py b/deepmd/tf/descriptor/se_a_ebd_v2.py index 0d2acbc9d5..9b92931b7f 100644 --- a/deepmd/tf/descriptor/se_a_ebd_v2.py +++ b/deepmd/tf/descriptor/se_a_ebd_v2.py @@ -24,7 +24,7 @@ class DescrptSeAEbdV2(DescrptSeA): r"""A compressible se_a_ebd model. - This model is a warpper for DescriptorSeA, which set stripped_type_embedding=True. + This model is a warpper for DescriptorSeA, which set tebd_input_mode='strip'. """ def __init__( @@ -65,6 +65,6 @@ def __init__( uniform_seed=uniform_seed, multi_task=multi_task, spin=spin, - stripped_type_embedding=True, + tebd_input_mode="strip", **kwargs, ) diff --git a/deepmd/tf/descriptor/se_a_mask.py b/deepmd/tf/descriptor/se_a_mask.py index d1ae5d7bad..e78dfba461 100644 --- a/deepmd/tf/descriptor/se_a_mask.py +++ b/deepmd/tf/descriptor/se_a_mask.py @@ -128,7 +128,7 @@ def __init__( activation_function: str = "tanh", precision: str = "default", uniform_seed: bool = False, - stripped_type_embedding: bool = False, + tebd_input_mode: str = "concat", **kwargs, ) -> None: """Constructor.""" @@ -160,6 +160,8 @@ def __init__( # numb of neighbors and numb of descrptors self.nnei_a = np.cumsum(self.sel_a)[-1] self.nnei = self.nnei_a + # to be compat with old option of `stripped_type_embedding` + stripped_type_embedding = tebd_input_mode == "strip" self.stripped_type_embedding = stripped_type_embedding self.ndescrpt_a = self.nnei_a * 4 diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 9b13c3944b..bc7315e66a 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -20,8 +20,10 @@ EnvMat, ) from deepmd.dpmodel.utils.network import ( + EmbeddingNet, LayerNorm, NativeLayer, + NetworkCollection, ) from deepmd.tf.common import ( cast_precision, @@ -116,7 +118,9 @@ class DescrptSeAtten(DescrptSeA): seed: int, Optional Random seed for initializing the network parameters. type_one_side: bool - Try to build N_types embedding nets. Otherwise, building N_types^2 embedding nets + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. exclude_types : List[List[int]] The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1. @@ -140,14 +144,22 @@ class DescrptSeAtten(DescrptSeA): The epsilon value for layer normalization. multi_task: bool If the model has multi fitting nets to train. - stripped_type_embedding: bool - Whether to strip the type embedding into a separated embedding network. - Default value will be True in `se_atten_v2` descriptor. + tebd_input_mode: str + The input mode of the type embedding. Supported modes are ["concat", "strip"]. + - "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. + Default value will be `strip` in `se_atten_v2` descriptor. smooth_type_embedding: bool Whether to use smooth process in attention weights calculation. And when using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True. Default value will be True in `se_atten_v2` descriptor. + stripped_type_embedding: bool, Optional + (Deprecated, kept only for compatibility.) + Whether to strip the type embedding into a separate embedding network. + Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'. + Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'. + The default value is `None`, which means the `tebd_input_mode` setting will be used instead. Raises ------ @@ -177,8 +189,8 @@ def __init__( attn_dotr: bool = True, attn_mask: bool = False, multi_task: bool = False, - stripped_type_embedding: bool = False, smooth_type_embedding: bool = False, + tebd_input_mode: str = "concat", # not implemented scaling_factor=1.0, normalize=True, @@ -187,8 +199,15 @@ def __init__( ln_eps: Optional[float] = 1e-3, concat_output_tebd: bool = True, env_protection: float = 0.0, # not implement!! + stripped_type_embedding: Optional[bool] = None, **kwargs, ) -> None: + # Ensure compatibility with the deprecated stripped_type_embedding option. + if stripped_type_embedding is None: + stripped_type_embedding = tebd_input_mode == "strip" + else: + # Use the user-set stripped_type_embedding parameter first + tebd_input_mode = "strip" if stripped_type_embedding else "concat" if not set_davg_zero and not ( stripped_type_embedding and smooth_type_embedding ): @@ -239,6 +258,7 @@ def __init__( if ntypes == 0: raise ValueError("`model/type_map` is not set or empty!") self.stripped_type_embedding = stripped_type_embedding + self.tebd_input_mode = tebd_input_mode self.smooth = smooth_type_embedding self.trainable_ln = trainable_ln self.ln_eps = ln_eps @@ -1368,7 +1388,6 @@ def compat_ln_pattern(old_key): graph_def, suffix, get_extra_embedding_net_suffix(type_one_side=False), - self.layer_size, ) ) @@ -1577,6 +1596,89 @@ def serialize_attention_layers( ) return data + def serialize_network_strip( + self, + ntypes: int, + ndim: int, + in_dim: int, + neuron: List[int], + activation_function: str, + resnet_dt: bool, + variables: dict, + suffix: str = "", + type_one_side: bool = False, + ) -> dict: + """Serialize network. + + Parameters + ---------- + ntypes : int + The number of types + ndim : int + The dimension of elements + in_dim : int + The input dimension + neuron : List[int] + The neuron list + activation_function : str + The activation function + resnet_dt : bool + Whether to use resnet + variables : dict + The input variables + suffix : str, optional + The suffix of the scope + type_one_side : bool, optional + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. + + Returns + ------- + dict + The converted network data + """ + assert ndim == 0, "only supports descriptors with type embedding!" + embeddings = NetworkCollection( + ntypes=ntypes, + ndim=ndim, + network_type="embedding_network", + ) + name_suffix = get_extra_embedding_net_suffix(type_one_side=type_one_side) + embedding_net_pattern_strip = str( + rf"filter_type_(all)/(matrix)_(\d+){name_suffix}|" + rf"filter_type_(all)/(bias)_(\d+){name_suffix}|" + rf"filter_type_(all)/(idt)_(\d+){name_suffix}|" + )[:-1] + if suffix != "": + embedding_net_pattern = ( + embedding_net_pattern_strip.replace("/(idt)", suffix + "/(idt)") + .replace("/(bias)", suffix + "/(bias)") + .replace("/(matrix)", suffix + "/(matrix)") + ) + else: + embedding_net_pattern = embedding_net_pattern_strip + for key, value in variables.items(): + m = re.search(embedding_net_pattern, key) + m = [mm for mm in m.groups() if mm is not None] + layer_idx = int(m[2]) - 1 + weight_name = m[1] + network_idx = () + if embeddings[network_idx] is None: + # initialize the network if it is not initialized + embeddings[network_idx] = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + activation_function=activation_function, + resnet_dt=resnet_dt, + precision=self.precision.name, + ) + assert embeddings[network_idx] is not None + if weight_name == "idt": + value = value.ravel() + embeddings[network_idx][layer_idx][weight_name] = value + return embeddings.serialize() + @classmethod def deserialize_attention_layers(cls, data: dict, suffix: str = "") -> dict: """Deserialize attention layers. @@ -1653,6 +1755,53 @@ def deserialize_attention_layers(cls, data: dict, suffix: str = "") -> dict: ] = layer_norm["matrix"] return attention_layer_variables + @classmethod + def deserialize_network_strip( + cls, data: dict, suffix: str = "", type_one_side: bool = False + ) -> dict: + """Deserialize network. + + Parameters + ---------- + data : dict + The input network data + suffix : str, optional + The suffix of the scope + type_one_side : bool, optional + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. + + Returns + ------- + variables : dict + The input variables + """ + embedding_net_variables = {} + embeddings = NetworkCollection.deserialize(data) + assert embeddings.ndim == 0, "only supports descriptors with type embedding!" + name_suffix = get_extra_embedding_net_suffix(type_one_side=type_one_side) + net_idx = () + network = embeddings[net_idx] + assert network is not None + for layer_idx, layer in enumerate(network.layers): + embedding_net_variables[ + f"filter_type_all{suffix}/matrix_{layer_idx + 1}{name_suffix}" + ] = layer.w + embedding_net_variables[ + f"filter_type_all{suffix}/bias_{layer_idx + 1}{name_suffix}" + ] = layer.b + if layer.idt is not None: + embedding_net_variables[ + f"filter_type_all{suffix}/idt_{layer_idx + 1}{name_suffix}" + ] = layer.idt.reshape(1, -1) + else: + # prevent keyError + embedding_net_variables[ + f"filter_type_all{suffix}/idt_{layer_idx + 1}{name_suffix}" + ] = 0.0 + return embedding_net_variables + @classmethod def deserialize(cls, data: dict, suffix: str = ""): """Deserialize the model. @@ -1681,6 +1830,11 @@ def deserialize(cls, data: dict, suffix: str = ""): ) data.pop("env_mat") variables = data.pop("@variables") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + raise ValueError( + "Deserialization is unsupported for `tebd_input_mode='strip'` in the native model." + ) descriptor = cls(**data) descriptor.embedding_net_variables = embedding_net_variables descriptor.attention_layer_variables = attention_layer_variables @@ -1709,9 +1863,15 @@ def serialize(self, suffix: str = "") -> dict: raise NotImplementedError( f"Not implemented in class {self.__class__.__name__}" ) - if self.stripped_type_embedding: + if self.stripped_type_embedding and type(self) is not DescrptDPA1Compat: + # only DescrptDPA1Compat can serialize when tebd_input_mode=='strip' + raise NotImplementedError( + "serialization is unsupported by the native model when tebd_input_mode=='strip'" + ) + # todo support serialization when tebd_input_mode=='strip' and type_one_side is True + if self.stripped_type_embedding and self.type_one_side: raise NotImplementedError( - "stripped_type_embedding is unsupported by the native model" + "serialization is unsupported when tebd_input_mode=='strip' and type_one_side is True" ) if (self.original_sel != self.sel_a).any(): raise NotImplementedError( @@ -1724,7 +1884,7 @@ def serialize(self, suffix: str = "") -> dict: assert self.davg is not None assert self.dstd is not None - return { + data = { "@class": "Descriptor", "type": "se_atten", "@version": 1, @@ -1742,6 +1902,7 @@ def serialize(self, suffix: str = "") -> dict: "activation_function": self.activation_function_name, "resnet_dt": self.filter_resnet_dt, "smooth_type_embedding": self.smooth, + "tebd_input_mode": self.tebd_input_mode, "trainable_ln": self.trainable_ln, "ln_eps": self.ln_eps, "precision": self.filter_precision.name, @@ -1781,6 +1942,27 @@ def serialize(self, suffix: str = "") -> dict: "type_one_side": self.type_one_side, "spin": self.spin, } + if self.tebd_input_mode in ["strip"]: + assert ( + type(self) is DescrptDPA1Compat + ), "only DescrptDPA1Compat can serialize when tebd_input_mode=='strip'" + data.update( + { + "embeddings_strip": self.serialize_network_strip( + ntypes=self.ntypes, + ndim=0, + in_dim=2 + * self.tebd_dim, # only DescrptDPA1Compat has this attribute + neuron=self.filter_neuron, + activation_function=self.activation_function_name, + resnet_dt=self.filter_resnet_dt, + variables=self.two_side_embeeding_net_variables, + suffix=suffix, + type_one_side=self.type_one_side, + ) + } + ) + return data class DescrptDPA1Compat(DescrptSeAtten): @@ -1806,8 +1988,9 @@ class DescrptDPA1Compat(DescrptSeAtten): tebd_dim: int Dimension of the type embedding tebd_input_mode: str - (Only support `concat` to keep consistent with other backend references.) - The way to mix the type embeddings. + The input mode of the type embedding. Supported modes are ["concat", "strip"]. + - "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. resnet_dt: bool Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b) @@ -1898,10 +2081,6 @@ def __init__( seed: Optional[int] = None, uniform_seed: bool = False, ) -> None: - if tebd_input_mode != "concat": - raise NotImplementedError( - "Only support tebd_input_mode == `concat` in this version." - ) if not normalize: raise NotImplementedError("Only support normalize == True in this version.") if temperature != 1.0: @@ -1939,14 +2118,13 @@ def __init__( attn_dotr=attn_dotr, attn_mask=attn_mask, multi_task=True, - stripped_type_embedding=False, trainable_ln=trainable_ln, ln_eps=ln_eps, smooth_type_embedding=smooth_type_embedding, + tebd_input_mode=tebd_input_mode, env_protection=env_protection, ) self.tebd_dim = tebd_dim - self.tebd_input_mode = tebd_input_mode self.scaling_factor = scaling_factor self.normalize = normalize self.temperature = temperature @@ -2085,9 +2263,20 @@ def deserialize(cls, data: dict, suffix: str = ""): data.pop("env_mat") variables = data.pop("@variables") type_embedding = data.pop("type_embedding") + tebd_input_mode = data["tebd_input_mode"] + type_one_side = data["type_one_side"] + if tebd_input_mode in ["strip"]: + two_side_embeeding_net_variables = cls.deserialize_network_strip( + data.pop("embeddings_strip"), + suffix=suffix, + type_one_side=type_one_side, + ) + else: + two_side_embeeding_net_variables = None descriptor = cls(**data) descriptor.embedding_net_variables = embedding_net_variables descriptor.attention_layer_variables = attention_layer_variables + descriptor.two_side_embeeding_net_variables = two_side_embeeding_net_variables descriptor.davg = variables["davg"].reshape( descriptor.ntypes, descriptor.ndescrpt ) @@ -2117,7 +2306,6 @@ def serialize(self, suffix: str = "") -> dict: { "type": "dpa1", "tebd_dim": self.tebd_dim, - "tebd_input_mode": self.tebd_input_mode, "scaling_factor": self.scaling_factor, "normalize": self.normalize, "temperature": self.temperature, diff --git a/deepmd/tf/descriptor/se_atten_v2.py b/deepmd/tf/descriptor/se_atten_v2.py index 01c4d93ad8..61e672788e 100644 --- a/deepmd/tf/descriptor/se_atten_v2.py +++ b/deepmd/tf/descriptor/se_atten_v2.py @@ -109,7 +109,7 @@ def __init__( attn_dotr=attn_dotr, attn_mask=attn_mask, multi_task=multi_task, - stripped_type_embedding=True, + tebd_input_mode="strip", smooth_type_embedding=True, **kwargs, ) diff --git a/deepmd/tf/nvnmd/data/data.py b/deepmd/tf/nvnmd/data/data.py index 9e6dd4cc89..7f2c9ef5e9 100644 --- a/deepmd/tf/nvnmd/data/data.py +++ b/deepmd/tf/nvnmd/data/data.py @@ -332,7 +332,7 @@ "descriptor": { "seed": 1, "type": "se_atten", - "stripped_type_embedding": True, + "tebd_input_mode": "strip", "sel": 128, "rcut": 7.0, "rcut_smth": 0.5, diff --git a/deepmd/tf/utils/graph.py b/deepmd/tf/utils/graph.py index 53de9c9ce2..a891506e95 100644 --- a/deepmd/tf/utils/graph.py +++ b/deepmd/tf/utils/graph.py @@ -216,60 +216,70 @@ def get_extra_embedding_net_suffix(type_one_side: bool): return extra_suffix -def get_variables_from_graph_def_as_numpy_array(graph_def: tf.GraphDef, pattern: str): - """Get variables from the given tf.GraphDef object, with numpy array returns. +def get_extra_embedding_net_nodes_from_graph_def( + graph_def: tf.GraphDef, + suffix: str = "", + extra_suffix: str = "", +) -> Dict: + """Get the extra embedding net nodes with the given tf.GraphDef object. Parameters ---------- graph_def The input tf.GraphDef object - pattern : str - The name of variable + suffix : str, optional + The scope suffix + extra_suffix : str + The extra scope suffix Returns ------- - np.ndarray - The numpy array of the variable + Dict + The embedding net nodes within the given tf.GraphDef object """ - node = get_pattern_nodes_from_graph_def(graph_def, pattern)[pattern] - return tf.make_ndarray(node) + embedding_net_pattern_strip = str( + rf"filter_type_(all)/(matrix)_(\d+){extra_suffix}|" + rf"filter_type_(all)/(bias)_(\d+){extra_suffix}|" + rf"filter_type_(all)/(idt)_(\d+){extra_suffix}|" + )[:-1] + if suffix != "": + embedding_net_pattern_strip = ( + embedding_net_pattern_strip.replace("/(idt)", suffix + "/(idt)") + .replace("/(bias)", suffix + "/(bias)") + .replace("/(matrix)", suffix + "/(matrix)") + ) + + embedding_net_nodes_strip = get_pattern_nodes_from_graph_def( + graph_def, embedding_net_pattern_strip + ) + return embedding_net_nodes_strip def get_extra_embedding_net_variables_from_graph_def( - graph_def: tf.GraphDef, suffix: str, extra_suffix: str, layer_size: int -): - """Get extra embedding net variables from the given tf.GraphDef object. - The "extra embedding net" means the embedding net with only type embeddings input, - which occurs in "se_atten_v2" and "se_a_ebd_v2" descriptor. + graph_def: tf.GraphDef, + suffix: str = "", + extra_suffix: str = "", +) -> Dict: + """Get the embedding net variables with the given tf.GraphDef object. Parameters ---------- graph_def The input tf.GraphDef object - suffix : str - The "common" suffix in the descriptor - extra_suffix : str - This value depends on the value of "type_one_side". - It should always be "_one_side_ebd" or "_two_side_ebd" - layer_size : int - The layer size of the embedding net + suffix : str, optional + The suffix of the scope + extra_suffix + The extra scope suffix Returns ------- Dict - The extra embedding net variables within the given tf.GraphDef object + The embedding net variables within the given tf.GraphDef object """ - extra_embedding_net_variables = {} - for i in range(1, layer_size + 1): - matrix_pattern = f"filter_type_all{suffix}/matrix_{i}{extra_suffix}" - extra_embedding_net_variables[matrix_pattern] = ( - get_variables_from_graph_def_as_numpy_array(graph_def, matrix_pattern) - ) - bias_pattern = f"filter_type_all{suffix}/bias_{i}{extra_suffix}" - extra_embedding_net_variables[bias_pattern] = ( - get_variables_from_graph_def_as_numpy_array(graph_def, bias_pattern) - ) - return extra_embedding_net_variables + extra_embedding_net_nodes = get_extra_embedding_net_nodes_from_graph_def( + graph_def, extra_suffix=extra_suffix, suffix=suffix + ) + return convert_tensor_to_ndarray_in_dict(extra_embedding_net_nodes) def get_embedding_net_variables(model_file: str, suffix: str = "") -> Dict: diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 96ec7eb10f..c817536b92 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -424,7 +424,7 @@ def descrpt_se_atten_common_args(): + "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." ) doc_attn = "The length of hidden vectors in attention layers" - doc_attn_layer = "The number of attention layers." + doc_attn_layer = "The number of attention layers. Note that model compression of `se_atten` is only enabled when attn_layer==0 and tebd_input_mode=='strip'" doc_attn_dotr = "Whether to do dot product with the normalized relative coordinates" doc_attn_mask = "Whether to do mask on the diagonal in the attention matrix" @@ -475,7 +475,6 @@ def descrpt_se_atten_common_args(): @descrpt_args_plugin.register("se_atten", alias=["dpa1"]) def descrpt_se_atten_args(): - doc_stripped_type_embedding = "Whether to strip the type embedding into a separated embedding network. Setting it to `False` will fall back to the previous version of `se_atten` which is non-compressible." doc_smooth_type_embedding = f"Whether to use smooth process in attention weights calculation. {doc_only_tf_supported} When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True." doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used" doc_trainable_ln = ( @@ -495,7 +494,21 @@ def descrpt_se_atten_args(): doc_concat_output_tebd = ( "Whether to concat type embedding at the output of the descriptor." ) - doc_deprecated = "This feature will be removed in a future release." + doc_tebd_input_mode = ( + "The input mode of the type embedding. Supported modes are ['concat', 'strip']." + "- 'concat': Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. " + "When `type_one_side` is False, the input is `input_ij = concat([r_ij, tebd_j, tebd_i])`. When `type_one_side` is True, the input is `input_ij = concat([r_ij, tebd_j])`. " + "The output is `out_ij = embeding(input_ij)` for the pair-wise representation of atom i with neighbor j." + "- 'strip': Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. " + f"When `type_one_side` is False, the input is `input_t = concat([tebd_j, tebd_i])`. {doc_only_pt_supported} When `type_one_side` is True, the input is `input_t = tebd_j`. " + "The output is `out_ij = embeding_t(input_t) * embeding_s(r_ij) + embeding_s(r_ij)` for the pair-wise representation of atom i with neighbor j." + ) + doc_stripped_type_embedding = ( + "(Deprecated, kept only for compatibility.) Whether to strip the type embedding into a separate embedding network. " + "Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'. " + "Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'." + "The default value is `None`, which means the `tebd_input_mode` setting will be used instead." + ) return [ *descrpt_se_atten_common_args(), @@ -503,8 +516,8 @@ def descrpt_se_atten_args(): "stripped_type_embedding", bool, optional=True, - default=False, - doc=doc_only_tf_supported + doc_stripped_type_embedding, + default=None, + doc=doc_stripped_type_embedding, ), Argument( "smooth_type_embedding", @@ -534,7 +547,7 @@ def descrpt_se_atten_args(): str, optional=True, default="concat", - doc=doc_only_pt_supported + doc_deprecated, + doc=doc_tebd_input_mode, ), Argument( "scaling_factor", diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md index 74e4c9fa78..4d55383891 100644 --- a/doc/model/train-se-atten.md +++ b/doc/model/train-se-atten.md @@ -156,7 +156,7 @@ An example of the DPA-1 descriptor is provided as follows We highly recommend using the version 2.0 of the attention-based descriptor `"se_atten_v2"`, which is inherited from `"se_atten"` but with the following parameter modifications: ```json - "stripped_type_embedding": true, + "tebd_input_mode": "strip", "smooth_type_embedding": true, "set_davg_zero": false ``` diff --git a/examples/water/se_atten_dpa1_compat/input.json b/examples/water/se_atten_dpa1_compat/input.json index 90c597e586..3018096ae5 100644 --- a/examples/water/se_atten_dpa1_compat/input.json +++ b/examples/water/se_atten_dpa1_compat/input.json @@ -7,7 +7,7 @@ ], "descriptor": { "type": "se_atten", - "stripped_type_embedding": false, + "tebd_input_mode": "concat", "sel": 120, "rcut_smth": 0.50, "rcut": 6.00, diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index c0ca46c91e..a2d4ca074f 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -40,7 +40,7 @@ @parameterized( (4,), # tebd_dim - ("concat",), # tebd_input_mode + ("concat", "strip"), # tebd_input_mode (True,), # resnet_dt (True, False), # type_one_side (20,), # attn @@ -181,6 +181,7 @@ def skip_tf(self) -> bool: or not normalize or temperature != 1.0 or (excluded_types != [] and attn_layer > 0) + or (type_one_side and tebd_input_mode == "strip") # not consistent yet ) tf_class = DescrptDPA1TF diff --git a/source/tests/pt/model/test_dpa1.py b/source/tests/pt/model/test_dpa1.py index 7567f18593..c1b6f97b26 100644 --- a/source/tests/pt/model/test_dpa1.py +++ b/source/tests/pt/model/test_dpa1.py @@ -39,11 +39,12 @@ def test_consistency( dstd = rng.normal(size=(self.nt, nnei, 4)) dstd = 0.1 + np.abs(dstd) - for idt, prec, sm, to in itertools.product( + for idt, prec, sm, to, tm in itertools.product( [False, True], # resnet_dt ["float64", "float32"], # precision [False, True], # smooth_type_embedding [False, True], # type_one_side + ["concat", "strip"], # tebd_input_mode ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) @@ -60,6 +61,7 @@ def test_consistency( resnet_dt=idt, smooth_type_embedding=sm, type_one_side=to, + tebd_input_mode=tm, old_impl=False, ).to(env.DEVICE) dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) @@ -98,7 +100,7 @@ def test_consistency( err_msg=err_msg, ) # old impl - if idt is False and prec == "float64" and to is False: + if idt is False and prec == "float64" and to is False and tm == "concat": dd3 = DescrptDPA1( self.rcut, self.rcut_smth, @@ -163,16 +165,21 @@ def test_jit( dstd = rng.normal(size=(self.nt, nnei, 4)) dstd = 0.1 + np.abs(dstd) - for idt, prec, sm, to in itertools.product( - [False, True], - ["float64", "float32"], - [False, True], - [False, True], + for idt, prec, sm, to, tm in itertools.product( + [ + False, + ], # resnet_dt + [ + "float64", + ], # precision + [False, True], # smooth_type_embedding + [False, True], # type_one_side + ["concat", "strip"], # tebd_input_mode ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) err_msg = f"idt={idt} prec={prec}" - # sea new impl + # dpa1 new impl dd0 = DescrptDPA1( self.rcut, self.rcut_smth, @@ -182,6 +189,7 @@ def test_jit( resnet_dt=idt, smooth_type_embedding=sm, type_one_side=to, + tebd_input_mode=tm, old_impl=False, ) dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) diff --git a/source/tests/tf/test_data_large_batch.py b/source/tests/tf/test_data_large_batch.py index dad6bbf252..1232f8b1db 100644 --- a/source/tests/tf/test_data_large_batch.py +++ b/source/tests/tf/test_data_large_batch.py @@ -309,7 +309,7 @@ def test_stripped_data_mixed_type(self): jdata["model"]["descriptor"].pop("type", None) jdata["model"]["descriptor"]["ntypes"] = 2 - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True) jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out() @@ -507,7 +507,7 @@ def test_compressible_data_mixed_type(self): jdata["model"]["descriptor"].pop("type", None) jdata["model"]["descriptor"]["ntypes"] = 2 - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 0 descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True) jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() diff --git a/source/tests/tf/test_descrpt_se_atten.py b/source/tests/tf/test_descrpt_se_atten.py index 7a1bfd18f6..da5eb05650 100644 --- a/source/tests/tf/test_descrpt_se_atten.py +++ b/source/tests/tf/test_descrpt_se_atten.py @@ -421,7 +421,7 @@ def test_stripped_type_embedding_descriptor_two_sides(self): "resnet_dt": False, "seed": 1, } - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" # init models typeebd = TypeEmbedNet( @@ -588,7 +588,7 @@ def test_compressible_descriptor_two_sides(self): jdata["model"]["descriptor"]["neuron"] = [5, 5, 5] jdata["model"]["descriptor"]["axis_neuron"] = 2 jdata["model"]["descriptor"]["attn_layer"] = 0 - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" typeebd_param = { "neuron": [5], "resnet_dt": False, diff --git a/source/tests/tf/test_finetune_se_atten.py b/source/tests/tf/test_finetune_se_atten.py index 40fc5b68a3..ebb858b0bb 100644 --- a/source/tests/tf/test_finetune_se_atten.py +++ b/source/tests/tf/test_finetune_se_atten.py @@ -146,15 +146,15 @@ def setUpClass(cls) -> None: if not parse_version(tf.__version__) < parse_version("1.15"): def previous_se_atten(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = False + jdata["model"]["descriptor"]["tebd_input_mode"] = "concat" jdata["model"]["descriptor"]["attn_layer"] = 2 def stripped_model(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 2 def compressible_model(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 0 models = [previous_se_atten, stripped_model, compressible_model] diff --git a/source/tests/tf/test_init_frz_model_se_atten.py b/source/tests/tf/test_init_frz_model_se_atten.py index a114deffc8..25f629511c 100644 --- a/source/tests/tf/test_init_frz_model_se_atten.py +++ b/source/tests/tf/test_init_frz_model_se_atten.py @@ -136,15 +136,15 @@ def _init_models(model_setup, i): if not parse_version(tf.__version__) < parse_version("1.15"): def previous_se_atten(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = False + jdata["model"]["descriptor"]["tebd_input_mode"] = "concat" jdata["model"]["descriptor"]["attn_layer"] = 2 def stripped_model(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 2 def compressible_model(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 0 diff --git a/source/tests/tf/test_model_compression_se_atten.py b/source/tests/tf/test_model_compression_se_atten.py index 1ac82446c6..246990d207 100644 --- a/source/tests/tf/test_model_compression_se_atten.py +++ b/source/tests/tf/test_model_compression_se_atten.py @@ -79,7 +79,7 @@ def _init_models(): jdata["model"]["descriptor"] = {} jdata["model"]["descriptor"]["type"] = "se_atten" jdata["model"]["descriptor"]["precision"] = tests[i]["se_atten precision"] - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["sel"] = 120 jdata["model"]["descriptor"]["attn_layer"] = 0 jdata["model"]["descriptor"]["smooth_type_embedding"] = tests[i][ @@ -128,7 +128,7 @@ def _init_models_exclude_types(): jdata["model"]["descriptor"]["type"] = "se_atten" jdata["model"]["descriptor"]["exclude_types"] = [[0, 1]] jdata["model"]["descriptor"]["precision"] = tests[i]["se_atten precision"] - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["sel"] = 120 jdata["model"]["descriptor"]["attn_layer"] = 0 jdata["model"]["type_embedding"] = {} diff --git a/source/tests/tf/test_model_se_atten.py b/source/tests/tf/test_model_se_atten.py index d75dc0cfff..a4b6575ecb 100644 --- a/source/tests/tf/test_model_se_atten.py +++ b/source/tests/tf/test_model_se_atten.py @@ -293,7 +293,7 @@ def test_compressible_model(self): jdata["model"]["descriptor"].pop("type", None) jdata["model"]["descriptor"]["ntypes"] = 2 - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 0 descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True) jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() @@ -461,7 +461,7 @@ def test_compressible_exclude_types(self): # successful descrpt = DescrptSeAtten(ntypes=ntypes, **jdata["model"]["descriptor"]) typeebd_param = jdata["model"]["type_embedding"] - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 0 typeebd = TypeEmbedNet( ntypes=descrpt.get_ntypes(), @@ -524,7 +524,7 @@ def test_stripped_type_embedding_model(self): jdata["model"]["descriptor"].pop("type", None) jdata["model"]["descriptor"]["ntypes"] = 2 - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 2 descrpt = DescrptSeAtten(**jdata["model"]["descriptor"], uniform_seed=True) jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() @@ -695,7 +695,7 @@ def test_stripped_type_embedding_exclude_types(self): # successful descrpt = DescrptSeAtten(ntypes=ntypes, **jdata["model"]["descriptor"]) typeebd_param = jdata["model"]["type_embedding"] - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["attn_layer"] = 2 typeebd = TypeEmbedNet( ntypes=descrpt.get_ntypes(), @@ -763,7 +763,7 @@ def test_smoothness_of_stripped_type_embedding_smooth_model(self): jdata["model"]["descriptor"].pop("type", None) jdata["model"]["descriptor"]["ntypes"] = 2 - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["smooth_type_embedding"] = True jdata["model"]["descriptor"]["attn_layer"] = 1 jdata["model"]["descriptor"]["rcut"] = 6.0 @@ -909,7 +909,7 @@ def test_smoothness_of_stripped_type_embedding_smooth_model_excluded_types(self) jdata["model"]["descriptor"].pop("type", None) jdata["model"]["descriptor"]["ntypes"] = 2 - jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["tebd_input_mode"] = "strip" jdata["model"]["descriptor"]["smooth_type_embedding"] = True jdata["model"]["descriptor"]["attn_layer"] = 1 jdata["model"]["descriptor"]["rcut"] = 6.0