From de3b048bd6a943f24fb33c43b60331f7beca1ef5 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 6 Jun 2024 17:29:10 -0400 Subject: [PATCH 1/2] fix(tf): throw RuntimeError for se_a + type_embedding Fix #3541. Signed-off-by: Jinzhe Zeng --- deepmd/tf/descriptor/se_a.py | 13 +++++++++---- source/tests/tf/test_model_se_a_type.py | 3 +++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/deepmd/tf/descriptor/se_a.py b/deepmd/tf/descriptor/se_a.py index 51c79f36af..1175842a5a 100644 --- a/deepmd/tf/descriptor/se_a.py +++ b/deepmd/tf/descriptor/se_a.py @@ -301,6 +301,7 @@ def __init__( self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt)) self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config) self.original_sel = None + self.use_tebd: Optional[bool] = None def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -746,8 +747,10 @@ def _pass_filter( ): if input_dict is not None: type_embedding = input_dict.get("type_embedding", None) + self.use_tebd = True else: type_embedding = None + self.use_tebd = False if self.stripped_type_embedding and type_embedding is None: raise RuntimeError("type_embedding is required for se_a_tebd_v2 model.") start_index = 0 @@ -1406,7 +1409,7 @@ def serialize(self, suffix: str = "") -> dict: raise NotImplementedError( "Serialization is unsupported when tebd_input_mode is set to 'strip'" ) - if (self.original_sel != self.sel_a).any(): + if self.original_sel is not None and (self.original_sel != self.sel_a).any(): raise NotImplementedError( "Adjusting sel is unsupported by the native model" ) @@ -1416,9 +1419,11 @@ def serialize(self, suffix: str = "") -> dict: raise NotImplementedError("spin is unsupported") assert self.davg is not None assert self.dstd is not None - # TODO: tf: handle type embedding in DescrptSeA.serialize - # not sure how to handle type embedding - type embedding is not a model parameter, - # but instead a part of the input data. Maybe the interface should be refactored... + assert self.use_tebd is not None + if self.use_tebd: + raise RuntimeError( + "Serialization is unsupported when type_embedding is used." + ) return { "@class": "Descriptor", diff --git a/source/tests/tf/test_model_se_a_type.py b/source/tests/tf/test_model_se_a_type.py index b0f5da6b7e..e38afc0fb4 100644 --- a/source/tests/tf/test_model_se_a_type.py +++ b/source/tests/tf/test_model_se_a_type.py @@ -179,3 +179,6 @@ def test_model(self): np.testing.assert_almost_equal(e, refe, places) np.testing.assert_almost_equal(f, reff, places) np.testing.assert_almost_equal(v, refv, places) + + with self.assertRaises(RuntimeError): + descrpt.serialize(suffix="se_a_type") From 81eaea31eb84221b627e4cc6846ec1aa5bcfae73 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 6 Jun 2024 17:51:05 -0400 Subject: [PATCH 2/2] fix the situation when input_dict is not empty Signed-off-by: Jinzhe Zeng --- deepmd/tf/descriptor/se_a.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/tf/descriptor/se_a.py b/deepmd/tf/descriptor/se_a.py index 1175842a5a..108e486da7 100644 --- a/deepmd/tf/descriptor/se_a.py +++ b/deepmd/tf/descriptor/se_a.py @@ -301,7 +301,8 @@ def __init__( self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt)) self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config) self.original_sel = None - self.use_tebd: Optional[bool] = None + # Whether type embedding is used + self.use_tebd: bool = False def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -747,10 +748,10 @@ def _pass_filter( ): if input_dict is not None: type_embedding = input_dict.get("type_embedding", None) - self.use_tebd = True + if type_embedding is not None: + self.use_tebd = True else: type_embedding = None - self.use_tebd = False if self.stripped_type_embedding and type_embedding is None: raise RuntimeError("type_embedding is required for se_a_tebd_v2 model.") start_index = 0 @@ -1419,7 +1420,6 @@ def serialize(self, suffix: str = "") -> dict: raise NotImplementedError("spin is unsupported") assert self.davg is not None assert self.dstd is not None - assert self.use_tebd is not None if self.use_tebd: raise RuntimeError( "Serialization is unsupported when type_embedding is used."