diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index 47e17e9eaa..a853b1722a 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -146,6 +146,9 @@ def iter( radial_only, protection=self.descriptor.env_protection, ) + # apply excluded_types + exclude_mask = self.descriptor.emask(nlist, extended_atype) + env_mat *= exclude_mask.unsqueeze(-1) # reshape to nframes * nloc at the atom level, # so nframes/mixed_type do not matter env_mat = env_mat.view( diff --git a/deepmd/tf/descriptor/descriptor.py b/deepmd/tf/descriptor/descriptor.py index 82b09c95fb..6a4ed5f354 100644 --- a/deepmd/tf/descriptor/descriptor.py +++ b/deepmd/tf/descriptor/descriptor.py @@ -7,6 +7,7 @@ Dict, List, Optional, + Set, Tuple, ) @@ -357,7 +358,7 @@ def pass_tensors_from_frz_model( def build_type_exclude_mask( self, - exclude_types: List[Tuple[int, int]], + exclude_types: Set[Tuple[int, int]], ntypes: int, sel: List[int], ndescrpt: int, diff --git a/deepmd/tf/descriptor/se_a.py b/deepmd/tf/descriptor/se_a.py index 7b22b3efd2..4f7897e76c 100644 --- a/deepmd/tf/descriptor/se_a.py +++ b/deepmd/tf/descriptor/se_a.py @@ -288,6 +288,18 @@ def __init__( sel_a=self.sel_a, sel_r=self.sel_r, ) + if len(self.exclude_types): + # exclude types applied to data stat + mask = self.build_type_exclude_mask( + self.exclude_types, + self.ntypes, + self.sel_a, + self.ndescrpt, + # for data stat, nloc == nall + self.place_holders["type"], + tf.size(self.place_holders["type"]), + ) + self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt)) self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config) self.original_sel = None self.multi_task = multi_task diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 82184dec02..0b26d83732 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -4,6 +4,7 @@ from typing import ( List, Optional, + Set, Tuple, ) @@ -282,6 +283,19 @@ def __init__( sel_a=self.sel_all_a, sel_r=self.sel_all_r, ) + if len(self.exclude_types): + # exclude types applied to data stat + mask = self.build_type_exclude_mask_mixed( + self.exclude_types, + self.ntypes, + self.sel_a, + self.ndescrpt, + # for data stat, nloc == nall + self.place_holders["type"], + tf.size(self.place_holders["type"]), + self.nei_type_vec_t, # extra input for atten + ) + self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt)) self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config) def compute_input_stats( @@ -672,7 +686,7 @@ def _pass_filter( inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) type_i = -1 if len(self.exclude_types): - mask = self.build_type_exclude_mask( + mask = self.build_type_exclude_mask_mixed( self.exclude_types, self.ntypes, self.sel_a, @@ -1367,9 +1381,9 @@ def init_variables( ) ) - def build_type_exclude_mask( + def build_type_exclude_mask_mixed( self, - exclude_types: List[Tuple[int, int]], + exclude_types: Set[Tuple[int, int]], ntypes: int, sel: List[int], ndescrpt: int, diff --git a/deepmd/tf/descriptor/se_r.py b/deepmd/tf/descriptor/se_r.py index 1443914aab..aef40b74bf 100644 --- a/deepmd/tf/descriptor/se_r.py +++ b/deepmd/tf/descriptor/se_r.py @@ -196,6 +196,18 @@ def __init__( rcut_smth=self.rcut_smth, sel=self.sel_r, ) + if len(self.exclude_types): + # exclude types applied to data stat + mask = self.build_type_exclude_mask( + self.exclude_types, + self.ntypes, + self.sel_r, + self.ndescrpt, + # for data stat, nloc == nall + self.place_holders["type"], + tf.size(self.place_holders["type"]), + ) + self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt)) self.sub_sess = tf.Session( graph=sub_graph, config=default_tf_session_config ) diff --git a/source/tests/pt/test_stat.py b/source/tests/pt/test_stat.py index 2362821dfa..b7e3f6e2d3 100644 --- a/source/tests/pt/test_stat.py +++ b/source/tests/pt/test_stat.py @@ -337,6 +337,44 @@ def tf_compute_input_stats(self): ) +class TestExcludeTypes(DatasetTest, unittest.TestCase): + def setup_data(self): + original_data = str(Path(__file__).parent / "water/data/data_0") + picked_data = str(Path(__file__).parent / "picked_data_for_test_stat") + dpdata.LabeledSystem(original_data, fmt="deepmd/npy")[:2].to_deepmd_npy( + picked_data + ) + self.mixed_type = False + return picked_data + + def setup_tf(self): + return DescrptSeA_tf( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + sel=self.sel, + neuron=self.filter_neuron, + axis_neuron=self.axis_neuron, + exclude_types=[[0, 0], [1, 1]], + ) + + def setup_pt(self): + return DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + self.filter_neuron, + self.axis_neuron, + exclude_types=[[0, 0], [1, 1]], + ).sea # get the block who has stat as private vars + + def tf_compute_input_stats(self): + coord = self.dp_merged["coord"] + atype = self.dp_merged["type"] + natoms = self.dp_merged["natoms_vec"] + box = self.dp_merged["box"] + self.dp_d.compute_input_stats(coord, box, atype, natoms, self.dp_mesh, {}) + + class TestOutputStat(unittest.TestCase): def setUp(self): self.data_file = [str(Path(__file__).parent / "water/data/data_0")]