Skip to content

Commit

Permalink
feat(pt/dpmodel): support type_one_side in se_e2_a (#3339)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Feb 27, 2024
1 parent a3f4a67 commit 4f70073
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 39 deletions.
58 changes: 39 additions & 19 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import itertools

import numpy as np

from deepmd.utils.path import (
Expand Down Expand Up @@ -144,8 +146,6 @@ def __init__(
seed: Optional[int] = None,
) -> None:
## seed, uniform_seed, multi_task, not included.
if not type_one_side:
raise NotImplementedError("type_one_side == False not implemented")
if spin is not None:
raise NotImplementedError("spin is not implemented")

Expand All @@ -171,10 +171,10 @@ def __init__(
ndim=(1 if self.type_one_side else 2),
network_type="embedding_network",
)
if not self.type_one_side:
raise NotImplementedError("type_one_side == False not implemented")
for ii in range(self.ntypes):
self.embeddings[(ii,)] = EmbeddingNet(
for embedding_idx in itertools.product(
range(self.ntypes), repeat=self.embeddings.ndim
):
self.embeddings[embedding_idx] = EmbeddingNet(
in_dim,
self.neuron,
self.activation_function,
Expand Down Expand Up @@ -241,12 +241,12 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None)
def cal_g(
self,
ss,
ll,
embedding_idx,
):
nf, nloc, nnei = ss.shape[0:3]
ss = ss.reshape(nf, nloc, nnei, 1)
# nf x nloc x nnei x ng
gg = self.embeddings[(ll,)].call(ss)
nf_times_nloc, nnei = ss.shape[0:2]
ss = ss.reshape(nf_times_nloc, nnei, 1)
# (nf x nloc) x nnei x ng
gg = self.embeddings[embedding_idx].call(ss)
return gg

def call(
Expand Down Expand Up @@ -292,16 +292,30 @@ def call(
sec = np.append([0], np.cumsum(self.sel))

ng = self.neuron[-1]
gr = np.zeros([nf, nloc, ng, 4])
gr = np.zeros([nf * nloc, ng, 4])
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
for tt in range(self.ntypes):
mm = exclude_mask[:, :, sec[tt] : sec[tt + 1]]
tr = rr[:, :, sec[tt] : sec[tt + 1], :]
tr = tr * mm[:, :, :, None]
# merge nf and nloc axis, so for type_one_side == False,
# we don't require atype is the same in all frames
exclude_mask = exclude_mask.reshape(nf * nloc, nnei)
rr = rr.reshape(nf * nloc, nnei, 4)

for embedding_idx in itertools.product(
range(self.ntypes), repeat=self.embeddings.ndim
):
if self.type_one_side:
(tt,) = embedding_idx
ti_mask = np.s_[:]
else:
ti, tt = embedding_idx
ti_mask = atype_ext[:, :nloc].ravel() == ti
mm = exclude_mask[ti_mask, sec[tt] : sec[tt + 1]]
tr = rr[ti_mask, sec[tt] : sec[tt + 1], :]
tr = tr * mm[:, :, None]
ss = tr[..., 0:1]
gg = self.cal_g(ss, tt)
# nf x nloc x ng x 4
gr += np.einsum("flni,flnj->flij", gg, tr)
gg = self.cal_g(ss, embedding_idx)
gr_tmp = np.einsum("lni,lnj->lij", gg, tr)
gr[ti_mask] += gr_tmp
gr = gr.reshape(nf, nloc, ng, 4)
# nf x nloc x ng x 4
gr /= self.nnei
gr1 = gr[:, :, : self.axis_neuron, :]
Expand All @@ -313,6 +327,12 @@ def call(

def serialize(self) -> dict:
"""Serialize the descriptor to dict."""
if not self.type_one_side and self.exclude_types:
for embedding_idx in itertools.product(range(self.ntypes), repeat=2):
# not actually used; to match serilization data from TF to pass the test
if embedding_idx in self.emask:
self.embeddings[embedding_idx].clear()

return {
"@class": "Descriptor",
"type": "se_e2_a",
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,6 @@ def build_type_exclude_mask(
type_ij = type_ij.reshape(nf, nloc * nnei)
mask = self.type_mask[type_ij].reshape(nf, nloc, nnei)
return mask

def __contains__(self, item):
return item in self.exclude_types
9 changes: 9 additions & 0 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,15 @@ def call(self, x):
x = layer(x)
return x

def clear(self):
"""Clear the network parameters to zero."""
for layer in self.layers:
layer.w.fill(0.0)
if layer.b is not None:
layer.b.fill(0.0)
if layer.idt is not None:
layer.idt.fill(0.0)

return NN


Expand Down
36 changes: 27 additions & 9 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import itertools
from typing import (
ClassVar,
Dict,
Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(
resnet_dt: bool = False,
exclude_types: List[Tuple[int, int]] = [],
old_impl: bool = False,
type_one_side: bool = True,
**kwargs,
):
super().__init__()
Expand All @@ -82,6 +84,7 @@ def __init__(
resnet_dt=resnet_dt,
exclude_types=exclude_types,
old_impl=old_impl,
type_one_side=type_one_side,
**kwargs,
)

Expand Down Expand Up @@ -214,7 +217,7 @@ def serialize(self) -> dict:
},
## to be updated when the options are supported.
"trainable": True,
"type_one_side": True,
"type_one_side": obj.type_one_side,
"spin": None,
}

Expand Down Expand Up @@ -255,6 +258,7 @@ def __init__(
resnet_dt: bool = False,
exclude_types: List[Tuple[int, int]] = [],
old_impl: bool = False,
type_one_side: bool = True,
**kwargs,
):
"""Construct an embedding net of type `se_a`.
Expand All @@ -281,6 +285,7 @@ def __init__(
self.exclude_types = exclude_types
self.ntypes = len(sel)
self.emask = PairExcludeMask(len(sel), exclude_types=exclude_types)
self.type_one_side = type_one_side

self.sel = sel
self.sec = torch.tensor(
Expand All @@ -299,6 +304,10 @@ def __init__(
self.filter_layers = None

if self.old_impl:
if not self.type_one_side:
raise ValueError(
"The old implementation does not support type_one_side=False."
)
filter_layers = []
# TODO: remove
start_index = 0
Expand All @@ -308,12 +317,12 @@ def __init__(
start_index += sel[type_i]
self.filter_layers_old = torch.nn.ModuleList(filter_layers)
else:
ndim = 1 if self.type_one_side else 2
filter_layers = NetworkCollection(
ndim=1, ntypes=len(sel), network_type="embedding_network"
ndim=ndim, ntypes=len(sel), network_type="embedding_network"
)
# TODO: ndim=2 if type_one_side=False
for ii in range(self.ntypes):
filter_layers[(ii,)] = EmbeddingNet(
for embedding_idx in itertools.product(range(self.ntypes), repeat=ndim):
filter_layers[embedding_idx] = EmbeddingNet(
1,
self.filter_neuron,
activation_function=self.activation_function,
Expand Down Expand Up @@ -473,18 +482,27 @@ def forward(
)
# nfnl x nnei
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1)
for ii, ll in enumerate(self.filter_layers.networks):
for embedding_idx, ll in enumerate(self.filter_layers.networks):
if self.type_one_side:
ii = embedding_idx
# torch.jit is not happy with slice(None)
ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device)
else:
# ti: center atom type, ii: neighbor type...
ii = embedding_idx // self.ntypes
ti = embedding_idx % self.ntypes
ti_mask = atype.ravel().eq(ti)
# nfnl x nt
mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]]
mm = exclude_mask[ti_mask, self.sec[ii] : self.sec[ii + 1]]
# nfnl x nt x 4
rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :]
rr = dmatrix[ti_mask, self.sec[ii] : self.sec[ii + 1], :]
rr = rr * mm[:, :, None]
ss = rr[:, :, :1]
# nfnl x nt x ng
gg = ll.forward(ss)
# nfnl x 4 x ng
gr = torch.matmul(rr.permute(0, 2, 1), gg)
xyz_scatter += gr
xyz_scatter[ti_mask] += gr

xyz_scatter /= self.nnei
xyz_scatter_1 = xyz_scatter.permute(0, 2, 1)
Expand Down
11 changes: 2 additions & 9 deletions deepmd/tf/descriptor/se.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,15 +231,8 @@ def serialize_network(
resnet_dt=resnet_dt,
precision=self.precision.name,
)
for layer in range(len(neuron)):
embeddings[(type_i, type_j)][layer]["w"][:] = 0.0
embeddings[(type_i, type_j)][layer]["b"][:] = 0.0
if embeddings[(type_i, type_j)][layer]["idt"] is not None:
embeddings[(type_i, type_j)][layer]["idt"][:] = 0.0
embeddings[(type_j, type_i)][layer]["w"][:] = 0.0
embeddings[(type_j, type_i)][layer]["b"][:] = 0.0
if embeddings[(type_j, type_i)][layer]["idt"] is not None:
embeddings[(type_j, type_i)][layer]["idt"][:] = 0.0
embeddings[(type_i, type_j)].clear()
embeddings[(type_j, type_i)].clear()

if suffix != "":
embedding_net_pattern = (
Expand Down
15 changes: 13 additions & 2 deletions source/tests/consistent/descriptor/test_se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def skip_pt(self) -> bool:
excluded_types,
precision,
) = self.param
return not type_one_side or CommonTest.skip_pt
return CommonTest.skip_pt

@property
def skip_dp(self) -> bool:
Expand All @@ -81,7 +81,7 @@ def skip_dp(self) -> bool:
excluded_types,
precision,
) = self.param
return not type_one_side or CommonTest.skip_dp
return CommonTest.skip_dp

tf_class = DescrptSeATF
dp_class = DescrptSeADP
Expand Down Expand Up @@ -121,6 +121,17 @@ def setUp(self):
dtype=GLOBAL_NP_FLOAT_PRECISION,
)
self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)
# TF se_e2_a type_one_side=False requires atype sorted
(
resnet_dt,
type_one_side,
excluded_types,
precision,
) = self.param
if not type_one_side:
idx = np.argsort(self.atype)
self.atype = self.atype[idx]
self.coords = self.coords.reshape(-1, 3)[idx].ravel()

def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]:
return self.build_tf_descriptor(
Expand Down

0 comments on commit 4f70073

Please sign in to comment.