feat: apply descriptor exclude_types to env mat stat #3625

Merged · 8 commits · Apr 7, 2024
Changes from 5 commits
3 changes: 3 additions & 0 deletions deepmd/pt/utils/env_mat_stat.py
@@ -146,6 +146,9 @@
radial_only,
protection=self.descriptor.env_protection,
)
# apply excluded_types
exclude_mask = self.descriptor.emask(nlist, extended_atype)
env_mat *= exclude_mask.unsqueeze(-1)
# reshape to nframes * nloc at the atom level,
# so nframes/mixed_type do not matter
env_mat = env_mat.view(
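The PyTorch change above multiplies the per-neighbor environment matrix by a pair-type exclusion mask (`self.descriptor.emask(nlist, extended_atype)`) before the statistics are accumulated, so excluded (center, neighbor) type pairs no longer contribute to the mean and standard deviation. A minimal sketch of the same idea in plain PyTorch — the shapes, padding handling, and function name here are illustrative, not the actual `emask` implementation:

```python
import torch

def apply_pair_exclusion(env_mat, nlist, extended_atype, exclude_types):
    """Zero env_mat entries whose (center, neighbor) type pair is excluded.

    env_mat:        (nframes, nloc, nnei, 4) per-neighbor environment matrix
    nlist:          (nframes, nloc, nnei)    neighbor indices, -1 for padding
    extended_atype: (nframes, nall)          types of local + ghost atoms
    exclude_types:  iterable of (ti, tj) pairs to drop from the statistics
    """
    nf, nloc, nnei = nlist.shape
    ntypes = int(extended_atype.max()) + 1
    # symmetric lookup table: keep[ti, tj] == 0 if the pair is excluded
    keep = torch.ones(ntypes, ntypes, dtype=env_mat.dtype)
    for ti, tj in exclude_types:
        keep[ti, tj] = 0.0
        keep[tj, ti] = 0.0
    # gather neighbor types; padded entries (-1) are masked out afterwards
    padded = nlist < 0
    neigh_idx = nlist.clamp(min=0).reshape(nf, -1)
    neigh_type = torch.gather(extended_atype, 1, neigh_idx).reshape(nf, nloc, nnei)
    center_type = extended_atype[:, :nloc].unsqueeze(-1).expand(nf, nloc, nnei)
    mask = keep[center_type, neigh_type].masked_fill(padded, 0.0)
    # broadcast over the 4-component axis, like `env_mat *= exclude_mask.unsqueeze(-1)`
    return env_mat * mask.unsqueeze(-1)
```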
3 changes: 2 additions & 1 deletion deepmd/tf/descriptor/descriptor.py
@@ -7,6 +7,7 @@
Dict,
List,
Optional,
Set,
Tuple,
)

@@ -357,7 +358,7 @@ def pass_tensors_from_frz_model(

def build_type_exclude_mask(
self,
exclude_types: List[Tuple[int, int]],
exclude_types: Set[Tuple[int, int]],
ntypes: int,
sel: List[int],
ndescrpt: int,
12 changes: 12 additions & 0 deletions deepmd/tf/descriptor/se_a.py
@@ -288,6 +288,18 @@
sel_a=self.sel_a,
sel_r=self.sel_r,
)
if len(self.exclude_types):
# exclude types applied to data stat
mask = self.build_type_exclude_mask(
self.exclude_types,
self.ntypes,
self.sel_a,
self.ndescrpt,
# for data stat, nloc == nall
self.place_holders["type"],
tf.size(self.place_holders["type"]),
)
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)
self.original_sel = None
self.multi_task = multi_task
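On the TensorFlow side, `se_a` now applies `build_type_exclude_mask` to `stat_descrpt` so the data statistics see the same exclusions as the fitting graph. In the flat se_e2_a layout the neighbor slots of a descriptor row are grouped by neighbor type, so an excluded (center type, neighbor type) pair amounts to zeroing a contiguous block of each affected atom's row. A NumPy sketch of that idea under assumed shapes (not the placeholder-based graph code):

```python
import numpy as np

def flat_exclude_mask(exclude_types, ntypes, sel_a, ndescrpt, atype):
    """Return a (natoms, ndescrpt) 0/1 mask for flat se_a descriptor rows.

    exclude_types: set of (ti, tj) pairs whose contributions are dropped
    sel_a:         max neighbors per neighbor type; ndescrpt == 4 * sum(sel_a)
    atype:         (natoms,) center-atom types
    """
    pairs = {(ti, tj) for ti, tj in exclude_types} | {(tj, ti) for ti, tj in exclude_types}
    # one template row per center type: zeros over the slots of excluded neighbor types
    template = np.ones((ntypes, ndescrpt))
    offset = 0
    for tj, nsel in enumerate(sel_a):
        width = 4 * nsel  # each neighbor slot carries 4 env-mat components
        for ti in range(ntypes):
            if (ti, tj) in pairs:
                template[ti, offset : offset + width] = 0.0
        offset += width
    return template[atype]  # pick the template row matching each atom's type
```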
20 changes: 17 additions & 3 deletions deepmd/tf/descriptor/se_atten.py
@@ -4,6 +4,7 @@
from typing import (
List,
Optional,
Set,
Tuple,
)

@@ -282,6 +283,19 @@
sel_a=self.sel_all_a,
sel_r=self.sel_all_r,
)
if len(self.exclude_types):
# exclude types applied to data stat
mask = self.build_type_exclude_mask_mixed(
self.exclude_types,
self.ntypes,
self.sel_a,
self.ndescrpt,
# for data stat, nloc == nall
self.place_holders["type"],
tf.size(self.place_holders["type"]),
self.nei_type_vec_t, # extra input for atten
)
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)

def compute_input_stats(
@@ -672,7 +686,7 @@
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
type_i = -1
if len(self.exclude_types):
mask = self.build_type_exclude_mask(
mask = self.build_type_exclude_mask_mixed(
self.exclude_types,
self.ntypes,
self.sel_a,
@@ -1357,9 +1371,9 @@
)
)

def build_type_exclude_mask(
def build_type_exclude_mask_mixed(
self,
exclude_types: List[Tuple[int, int]],
exclude_types: Set[Tuple[int, int]],
ntypes: int,
sel: List[int],
ndescrpt: int,
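In `se_atten` the existing builder is renamed to `build_type_exclude_mask_mixed` and reused for the statistics. Because the mixed-type layout does not group slots by neighbor type, the mask needs an extra per-neighbor type vector (`self.nei_type_vec_t` above). A sketch of how that differs from the grouped case, again with illustrative shapes and an assumed padding convention:

```python
import numpy as np

def mixed_exclude_mask(exclude_types, ntypes, atype, nei_type):
    """Return a (natoms, 4 * nnei) 0/1 mask for mixed-type descriptor rows.

    atype:    (natoms,)       center-atom types
    nei_type: (natoms, nnei)  per-neighbor types, padding marked as ntypes
    """
    pairs = {(ti, tj) for ti, tj in exclude_types} | {(tj, ti) for ti, tj in exclude_types}
    keep = np.ones((ntypes + 1, ntypes + 1))  # extra row/column keeps padded neighbors
    for ti, tj in pairs:
        keep[ti, tj] = 0.0
    mask = keep[atype[:, None], nei_type]  # (natoms, nnei), looked up per neighbor
    return np.repeat(mask, 4, axis=1)      # expand over the 4 env-mat components
```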
12 changes: 12 additions & 0 deletions deepmd/tf/descriptor/se_r.py
@@ -196,6 +196,18 @@
rcut_smth=self.rcut_smth,
sel=self.sel_r,
)
if len(self.exclude_types):
# exclude types applied to data stat
mask = self.build_type_exclude_mask(
self.exclude_types,
self.ntypes,
self.sel_a,
self.ndescrpt,
# for data stat, nloc == nall
self.place_holders["type"],
tf.size(self.place_holders["type"]),
)
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
self.sub_sess = tf.Session(
graph=sub_graph, config=default_tf_session_config
)
38 changes: 38 additions & 0 deletions source/tests/pt/test_stat.py
@@ -325,5 +325,43 @@ def tf_compute_input_stats(self):
)


class TestExcludeTypes(DatasetTest, unittest.TestCase):
def setup_data(self):
original_data = str(Path(__file__).parent / "water/data/data_0")
picked_data = str(Path(__file__).parent / "picked_data_for_test_stat")
dpdata.LabeledSystem(original_data, fmt="deepmd/npy")[:2].to_deepmd_npy(
picked_data
)
self.mixed_type = False
return picked_data

def setup_tf(self):
return DescrptSeA_tf(
rcut=self.rcut,
rcut_smth=self.rcut_smth,
sel=self.sel,
neuron=self.filter_neuron,
axis_neuron=self.axis_neuron,
exclude_types=[[0, 0], [1, 1]],
)

def setup_pt(self):
return DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
self.filter_neuron,
self.axis_neuron,
exclude_types=[[0, 0], [1, 1]],
).sea  # get the block that holds the stat as private variables

def tf_compute_input_stats(self):
coord = self.dp_merged["coord"]
atype = self.dp_merged["type"]
natoms = self.dp_merged["natoms_vec"]
box = self.dp_merged["box"]
self.dp_d.compute_input_stats(coord, box, atype, natoms, self.dp_mesh, {})


if __name__ == "__main__":
unittest.main()
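The new test subclasses `DatasetTest` with matching TF and PT `se_e2_a` descriptors that both set `exclude_types=[[0, 0], [1, 1]]`, so the compared input statistics exercise the masking on both backends. The same pairs can be requested from a training input script; a hedged example of the descriptor section, where the other values are ordinary water-example settings rather than anything required by this PR:

```python
# Illustrative descriptor section of a training input, written as a Python dict
# (the JSON input script uses the same structure).
descriptor = {
    "type": "se_e2_a",
    "rcut": 6.0,
    "rcut_smth": 0.5,
    "sel": [46, 92],
    "neuron": [25, 50, 100],
    "axis_neuron": 16,
    # pairs dropped from the model and, with this PR, from the env-mat statistics
    "exclude_types": [[0, 0], [1, 1]],
}
```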