feat(pt/dpmodel): support type_one_side in se_e2_a #3339

Merged (3 commits, Feb 27, 2024). Changes shown are from 2 of the 3 commits.
63 changes: 44 additions & 19 deletions deepmd/dpmodel/descriptor/se_e2_a.py
@@ -1,4 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import itertools

import numpy as np

from deepmd.utils.path import (
@@ -144,8 +146,6 @@ def __init__(
seed: Optional[int] = None,
) -> None:
## seed, uniform_seed, multi_task, not included.
if not type_one_side:
raise NotImplementedError("type_one_side == False not implemented")
if spin is not None:
raise NotImplementedError("spin is not implemented")

@@ -171,10 +171,10 @@ def __init__(
ndim=(1 if self.type_one_side else 2),
network_type="embedding_network",
)
if not self.type_one_side:
raise NotImplementedError("type_one_side == False not implemented")
for ii in range(self.ntypes):
self.embeddings[(ii,)] = EmbeddingNet(
for embedding_idx in itertools.product(
range(self.ntypes), repeat=self.embeddings.ndim
):
self.embeddings[embedding_idx] = EmbeddingNet(
in_dim,
self.neuron,
self.activation_function,
@@ -241,12 +241,12 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None)
def cal_g(
self,
ss,
ll,
embedding_idx,
):
nf, nloc, nnei = ss.shape[0:3]
ss = ss.reshape(nf, nloc, nnei, 1)
# nf x nloc x nnei x ng
gg = self.embeddings[(ll,)].call(ss)
nf_times_nloc, nnei = ss.shape[0:2]
ss = ss.reshape(nf_times_nloc, nnei, 1)
# (nf x nloc) x nnei x ng
gg = self.embeddings[embedding_idx].call(ss)
return gg

def call(
@@ -292,16 +292,30 @@ def call(
sec = np.append([0], np.cumsum(self.sel))

ng = self.neuron[-1]
gr = np.zeros([nf, nloc, ng, 4])
gr = np.zeros([nf * nloc, ng, 4])
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
for tt in range(self.ntypes):
mm = exclude_mask[:, :, sec[tt] : sec[tt + 1]]
tr = rr[:, :, sec[tt] : sec[tt + 1], :]
tr = tr * mm[:, :, :, None]
# merge the nf and nloc axes, so that for type_one_side == False
# we don't require atype to be the same in all frames
exclude_mask = exclude_mask.reshape(nf * nloc, nnei)
rr = rr.reshape(nf * nloc, nnei, 4)

for embedding_idx in itertools.product(
range(self.ntypes), repeat=self.embeddings.ndim
):
if self.type_one_side:
(tt,) = embedding_idx
ti_mask = np.s_[:]
else:
ti, tt = embedding_idx
ti_mask = atype_ext[:, :nloc].ravel() == ti
mm = exclude_mask[ti_mask, sec[tt] : sec[tt + 1]]
tr = rr[ti_mask, sec[tt] : sec[tt + 1], :]
tr = tr * mm[:, :, None]
ss = tr[..., 0:1]
gg = self.cal_g(ss, tt)
# nf x nloc x ng x 4
gr += np.einsum("flni,flnj->flij", gg, tr)
gg = self.cal_g(ss, embedding_idx)
gr_tmp = np.einsum("lni,lnj->lij", gg, tr)
gr[ti_mask] += gr_tmp
gr = gr.reshape(nf, nloc, ng, 4)
# nf x nloc x ng x 4
gr /= self.nnei
gr1 = gr[:, :, : self.axis_neuron, :]
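
A minimal NumPy sketch of the masked accumulation above, with toy shapes and random stand-ins (nothing here is deepmd API; gg would really come from cal_g):

    import numpy as np

    nf, nloc, nnei, ng = 2, 3, 4, 5
    rng = np.random.default_rng(0)

    rr = rng.random((nf * nloc, nnei, 4))   # frame and atom axes merged
    gg = rng.random((nf * nloc, nnei, ng))  # stand-in for the embedding output
    ti_mask = rng.random(nf * nloc) > 0.5   # stand-in for atype == ti

    gr = np.zeros((nf * nloc, ng, 4))
    # accumulate only into rows whose center atom has the selected type
    gr[ti_mask] += np.einsum("lni,lnj->lij", gg[ti_mask], rr[ti_mask])
    gr = gr.reshape(nf, nloc, ng, 4)  # back to nf x nloc x ng x 4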
@@ -313,6 +327,17 @@

def serialize(self) -> dict:
"""Serialize the descriptor to dict."""
if not self.type_one_side and self.exclude_types:
for embedding_idx in itertools.product(range(self.ntypes), repeat=2):
# not actually used; kept to match the serialization data from TF so the consistency test passes
if embedding_idx in self.emask:
for ilayer in range(len(self.neuron)):
layer = self.embeddings[embedding_idx][ilayer]
layer.w.fill(0.0)
layer.b.fill(0.0)
if layer.idt is not None:
layer.idt.fill(0.0)

return {
"@class": "Descriptor",
"type": "se_e2_a",
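To summarize the structural change in this file: the embedding networks are now indexed by type tuples from itertools.product, so one loop covers both modes. A standalone sketch of the indexing scheme (plain Python, not deepmd's NetworkCollection):

    import itertools

    def embedding_keys(ntypes, type_one_side):
        # ndim=1 -> keys are (neighbor_type,); ndim=2 -> (center_type, neighbor_type)
        ndim = 1 if type_one_side else 2
        return list(itertools.product(range(ntypes), repeat=ndim))

    print(embedding_keys(2, True))   # [(0,), (1,)]                     -> ntypes networks
    print(embedding_keys(2, False))  # [(0, 0), (0, 1), (1, 0), (1, 1)] -> ntypes**2 networks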
3 changes: 3 additions & 0 deletions deepmd/dpmodel/utils/exclude_mask.py
@@ -115,3 +115,6 @@ def build_type_exclude_mask(
type_ij = type_ij.reshape(nf, nloc * nnei)
mask = self.type_mask[type_ij].reshape(nf, nloc, nnei)
return mask

def __contains__(self, item):
return item in self.exclude_types
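
The new __contains__ hook is what lets serialize() above ask whether embedding_idx is an excluded pair. A minimal stand-in sketch (not the real PairExcludeMask, which may also symmetrize the stored pairs):

    class ExcludeMaskSketch:
        def __init__(self, exclude_types):
            self.exclude_types = set(exclude_types)

        def __contains__(self, item):
            return item in self.exclude_types

    emask = ExcludeMaskSketch([(0, 1)])
    print((0, 1) in emask)  # True: this type pair is excluded
    print((1, 1) in emask)  # False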
36 changes: 27 additions & 9 deletions deepmd/pt/model/descriptor/se_a.py
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import itertools
from typing import (
ClassVar,
Dict,
@@ -67,6 +68,7 @@
resnet_dt: bool = False,
exclude_types: List[Tuple[int, int]] = [],
old_impl: bool = False,
type_one_side: bool = True,
**kwargs,
):
super().__init__()
@@ -82,6 +84,7 @@
resnet_dt=resnet_dt,
exclude_types=exclude_types,
old_impl=old_impl,
type_one_side=type_one_side,
**kwargs,
)

@@ -214,7 +217,7 @@
},
## to be updated when the options are supported.
"trainable": True,
"type_one_side": True,
"type_one_side": obj.type_one_side,
"spin": None,
}

@@ -255,6 +258,7 @@
resnet_dt: bool = False,
exclude_types: List[Tuple[int, int]] = [],
old_impl: bool = False,
type_one_side: bool = True,
**kwargs,
):
"""Construct an embedding net of type `se_a`.
@@ -281,6 +285,7 @@
self.exclude_types = exclude_types
self.ntypes = len(sel)
self.emask = PairExcludeMask(len(sel), exclude_types=exclude_types)
self.type_one_side = type_one_side

self.sel = sel
self.sec = torch.tensor(
@@ -299,6 +304,10 @@
self.filter_layers = None

if self.old_impl:
if not self.type_one_side:
raise ValueError(
"The old implementation does not support type_one_side=False."
)
filter_layers = []
# TODO: remove
start_index = 0
@@ -308,12 +317,12 @@
start_index += sel[type_i]
self.filter_layers_old = torch.nn.ModuleList(filter_layers)
else:
ndim = 1 if self.type_one_side else 2
filter_layers = NetworkCollection(
ndim=1, ntypes=len(sel), network_type="embedding_network"
ndim=ndim, ntypes=len(sel), network_type="embedding_network"
)
# TODO: ndim=2 if type_one_side=False
for ii in range(self.ntypes):
filter_layers[(ii,)] = EmbeddingNet(
for embedding_idx in itertools.product(range(self.ntypes), repeat=ndim):
filter_layers[embedding_idx] = EmbeddingNet(
1,
self.filter_neuron,
activation_function=self.activation_function,
@@ -471,18 +480,27 @@
)
# nfnl x nnei
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1)
for ii, ll in enumerate(self.filter_layers.networks):
for embedding_idx, ll in enumerate(self.filter_layers.networks):
if self.type_one_side:
ii = embedding_idx
# torch.jit is not happy with slice(None)
ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device)
else:
# ti: center atom type, ii: neighbor type...
ii = embedding_idx // self.ntypes
ti = embedding_idx % self.ntypes
ti_mask = atype.ravel().eq(ti)
# nfnl x nt
mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]]
mm = exclude_mask[ti_mask, self.sec[ii] : self.sec[ii + 1]]
# nfnl x nt x 4
rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :]
rr = dmatrix[ti_mask, self.sec[ii] : self.sec[ii + 1], :]
rr = rr * mm[:, :, None]
ss = rr[:, :, :1]
# nfnl x nt x ng
gg = ll.forward(ss)
# nfnl x 4 x ng
gr = torch.matmul(rr.permute(0, 2, 1), gg)
xyz_scatter += gr
xyz_scatter[ti_mask] += gr

xyz_scatter /= self.nnei
xyz_scatter_1 = xyz_scatter.permute(0, 2, 1)
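Because TorchScript dislikes slice(None) (per the comment in the diff), the PyTorch forward iterates the networks by flat index and decodes the type pair arithmetically, scattering results back through a boolean mask. A toy sketch of the type_one_side=False bookkeeping (assumed shapes; torch.rand stands in for ll.forward):

    import torch

    nfnl, nnei, ng, ntypes = 8, 5, 4, 2
    atype = torch.randint(0, ntypes, (nfnl,))
    dmatrix = torch.rand(nfnl, nnei, 4)
    xyz_scatter = torch.zeros(nfnl, 4, ng)

    for embedding_idx in range(ntypes * ntypes):
        ii = embedding_idx // ntypes  # neighbor type (selects the sec slice in the real code)
        ti = embedding_idx % ntypes   # center atom type
        ti_mask = atype == ti         # boolean row mask instead of slice(None)
        rr = dmatrix[ti_mask]                    # (m, nnei, 4)
        gg = torch.rand(rr.shape[0], nnei, ng)   # stand-in for ll.forward(ss)
        # (m, 4, nnei) @ (m, nnei, ng) -> (m, 4, ng), scattered back by mask
        xyz_scatter[ti_mask] += torch.matmul(rr.permute(0, 2, 1), gg)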
15 changes: 13 additions & 2 deletions source/tests/consistent/descriptor/test_se_e2_a.py
@@ -71,7 +71,7 @@ def skip_pt(self) -> bool:
excluded_types,
precision,
) = self.param
return not type_one_side or CommonTest.skip_pt
return CommonTest.skip_pt

@property
def skip_dp(self) -> bool:
@@ -81,7 +81,7 @@ def skip_dp(self) -> bool:
excluded_types,
precision,
) = self.param
return not type_one_side or CommonTest.skip_dp
return CommonTest.skip_dp

tf_class = DescrptSeATF
dp_class = DescrptSeADP
@@ -121,6 +121,17 @@ def setUp(self):
dtype=GLOBAL_NP_FLOAT_PRECISION,
)
self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)
# TF se_e2_a type_one_side=False requires atype sorted
(
resnet_dt,
type_one_side,
excluded_types,
precision,
) = self.param
if not type_one_side:
idx = np.argsort(self.atype)
self.atype = self.atype[idx]
self.coords = self.coords.reshape(-1, 3)[idx].ravel()

def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]:
return self.build_tf_descriptor(
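
For reference, a toy version of the reordering added in setUp (made-up coordinates; the same permutation must be applied to atype and coords so atoms stay consistent):

    import numpy as np

    atype = np.array([0, 1, 0, 1])
    coords = np.arange(12, dtype=float)  # 4 atoms x 3 components, flattened

    idx = np.argsort(atype)              # e.g. [0, 2, 1, 3]
    atype = atype[idx]                   # [0, 0, 1, 1] -- sorted, as TF requires
    coords = coords.reshape(-1, 3)[idx].ravel()
    print(atype, coords)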