Skip to content

Commit

Permalink
feat: apply descriptor exclude_types to env mat stat (deepmodeling#3625)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 87d293a)
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz and pre-commit-ci[bot] committed Jul 2, 2024
1 parent dbf5d4a commit 478fbf3
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 4 deletions.
3 changes: 2 additions & 1 deletion deepmd/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Dict,
List,
Optional,
Set,
Tuple,
)

Expand Down Expand Up @@ -393,7 +394,7 @@ def pass_tensors_from_frz_model(

def build_type_exclude_mask(
self,
exclude_types: List[Tuple[int, int]],
exclude_types: Set[Tuple[int, int]],
ntypes: int,
sel: List[int],
ndescrpt: int,
Expand Down
12 changes: 12 additions & 0 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,18 @@ def __init__(
sel_a=self.sel_a,
sel_r=self.sel_r,
)
if len(self.exclude_types):
# exclude types applied to data stat
mask = self.build_type_exclude_mask(
self.exclude_types,
self.ntypes,
self.sel_a,
self.ndescrpt,
# for data stat, nloc == nall
self.place_holders["type"],
tf.size(self.place_holders["type"]),
)
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)
self.original_sel = None
self.multi_task = multi_task
Expand Down
20 changes: 17 additions & 3 deletions deepmd/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import (
List,
Optional,
Set,
Tuple,
)

Expand Down Expand Up @@ -250,6 +251,19 @@ def __init__(
sel_a=self.sel_all_a,
sel_r=self.sel_all_r,
)
if len(self.exclude_types):
# exclude types applied to data stat
mask = self.build_type_exclude_mask_mixed(
self.exclude_types,
self.ntypes,
self.sel_a,
self.ndescrpt,
# for data stat, nloc == nall
self.place_holders["type"],
tf.size(self.place_holders["type"]),
self.nei_type_vec_t, # extra input for atten
)
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)

def compute_input_stats(
Expand Down Expand Up @@ -640,7 +654,7 @@ def _pass_filter(
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
type_i = -1
if len(self.exclude_types):
mask = self.build_type_exclude_mask(
mask = self.build_type_exclude_mask_mixed(
self.exclude_types,
self.ntypes,
self.sel_a,
Expand Down Expand Up @@ -1335,9 +1349,9 @@ def init_variables(
)
)

def build_type_exclude_mask(
def build_type_exclude_mask_mixed(
self,
exclude_types: List[Tuple[int, int]],
exclude_types: Set[Tuple[int, int]],
ntypes: int,
sel: List[int],
ndescrpt: int,
Expand Down
12 changes: 12 additions & 0 deletions deepmd/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,18 @@ def __init__(
rcut_smth=self.rcut_smth,
sel=self.sel_r,
)
if len(self.exclude_types):
# exclude types applied to data stat
mask = self.build_type_exclude_mask(
self.exclude_types,
self.ntypes,
self.sel_r,
self.ndescrpt,
# for data stat, nloc == nall
self.place_holders["type"],
tf.size(self.place_holders["type"]),
)
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
self.sub_sess = tf.Session(
graph=sub_graph, config=default_tf_session_config
)
Expand Down

0 comments on commit 478fbf3

Please sign in to comment.