Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support stripped type embedding in DPA1 of PT/DP #3712

Merged
merged 18 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 64 additions & 17 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ class DescrptDPA1(NativeOP, BaseDescriptor):
tebd_dim: int
Dimension of the type embedding
tebd_input_mode: str
The way to mix the type embeddings. Supported options are `concat`.
(TODO need to support stripped_type_embedding option)
The input mode of the type embedding. Supported modes are ["concat", "strip"].
- "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network.
- "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output.
iProzd marked this conversation as resolved.
Show resolved Hide resolved
resnet_dt: bool
Time-step `dt` in the resnet construction:
y = x + dt * \phi (Wx + b)
Expand Down Expand Up @@ -189,9 +190,6 @@ class DescrptDPA1(NativeOP, BaseDescriptor):

Limitations
-----------
The currently implementation does not support the following features
1. tebd_input_mode != 'concat'

The currently implementation will not support the following deprecated features
1. spin is not None
2. attn_mask == True
Expand Down Expand Up @@ -243,9 +241,6 @@ def __init__(
raise NotImplementedError(
"old implementation of attn_mask is not supported."
)
# TODO
if tebd_input_mode != "concat":
raise NotImplementedError("tebd_input_mode != 'concat' not implemented")
# to keep consistent with default value in this backends
if ln_eps is None:
ln_eps = 1e-5
Expand Down Expand Up @@ -290,25 +285,38 @@ def __init__(
activation_function="Linear",
precision=precision,
)
self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2
iProzd marked this conversation as resolved.
Show resolved Hide resolved
if self.tebd_input_mode in ["concat"]:
if not self.type_one_side:
in_dim = 1 + self.tebd_dim * 2
else:
in_dim = 1 + self.tebd_dim
self.embd_input_dim = 1 + self.tebd_dim_input
iProzd marked this conversation as resolved.
Show resolved Hide resolved
else:
in_dim = 1
self.embd_input_dim = 1
self.embeddings = NetworkCollection(
ndim=0,
ntypes=self.ntypes,
network_type="embedding_network",
)
self.embeddings[0] = EmbeddingNet(
in_dim,
self.embd_input_dim,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
)
if self.tebd_input_mode in ["strip"]:
self.embeddings_strip = NetworkCollection(
ndim=0,
ntypes=self.ntypes,
network_type="embedding_network",
)
self.embeddings_strip[0] = EmbeddingNet(
self.tebd_dim_input,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
)
else:
self.embeddings_strip = None
self.dpa1_attention = NeighborGatedAttention(
self.attn_layer,
self.nnei,
Expand Down Expand Up @@ -410,6 +418,18 @@ def cal_g(
gg = self.embeddings[embedding_idx].call(ss)
return gg

def cal_g_strip(
self,
ss,
embedding_idx,
):
assert self.embeddings_strip is not None
nfnl, nnei = ss.shape[0:2]
ss = ss.reshape(nfnl, nnei, -1)
# nfnl x nnei x ng
gg = self.embeddings_strip[embedding_idx].call(ss)
return gg

def reinit_exclude(
self,
exclude_types: List[Tuple[int, int]] = [],
Expand Down Expand Up @@ -500,11 +520,28 @@ def call(
else:
# nfnl x nnei x (1 + tebd_dim)
ss = np.concatenate([ss, atype_embd_nlist], axis=-1)
# calculate gg
# nfnl x nnei x ng
gg = self.cal_g(ss, 0)
elif self.tebd_input_mode in ["strip"]:
# nfnl x nnei x ng
gg_s = self.cal_g(ss, 0)
assert self.embeddings_strip is not None
if not self.type_one_side:
# nfnl x nnei x (tebd_dim * 2)
tt = np.concatenate([atype_embd_nlist, atype_embd_nnei], axis=-1)
else:
# nfnl x nnei x tebd_dim
tt = atype_embd_nlist
# nfnl x nnei x ng
gg_t = self.cal_g_strip(tt, 0)
if self.smooth:
gg_t = gg_t * sw.reshape(-1, self.nnei, 1)
# nfnl x nnei x ng
gg = gg_s * gg_t + gg_s
else:
raise NotImplementedError

# calculate gg
gg = self.cal_g(ss, 0)
input_r = dmatrix.reshape(-1, nnei, 4)[:, :, 1:4] / np.maximum(
np.linalg.norm(
dmatrix.reshape(-1, nnei, 4)[:, :, 1:4], axis=-1, keepdims=True
Expand Down Expand Up @@ -532,7 +569,7 @@ def call(

def serialize(self) -> dict:
"""Serialize the descriptor to dict."""
return {
data = {
"@class": "Descriptor",
"type": "dpa1",
"@version": 1,
Expand Down Expand Up @@ -575,6 +612,9 @@ def serialize(self) -> dict:
"trainable": True,
"spin": None,
}
if self.tebd_input_mode in ["strip"]:
data.update({"embeddings_strip": self.embeddings_strip.serialize()})
return data

@classmethod
def deserialize(cls, data: dict) -> "DescrptDPA1":
Expand All @@ -588,11 +628,18 @@ def deserialize(cls, data: dict) -> "DescrptDPA1":
type_embedding = data.pop("type_embedding")
attention_layers = data.pop("attention_layers")
env_mat = data.pop("env_mat")
tebd_input_mode = data["tebd_input_mode"]
if tebd_input_mode in ["strip"]:
embeddings_strip = data.pop("embeddings_strip")
else:
embeddings_strip = None
obj = cls(**data)

obj["davg"] = variables["davg"]
obj["dstd"] = variables["dstd"]
obj.embeddings = NetworkCollection.deserialize(embeddings)
if tebd_input_mode in ["strip"]:
obj.embeddings_strip = NetworkCollection.deserialize(embeddings_strip)
obj.type_embedding = TypeEmbedNet.deserialize(type_embedding)
obj.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers)
return obj
Expand Down
31 changes: 17 additions & 14 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module):
tebd_dim: int
Dimension of the type embedding
tebd_input_mode: str
The way to mix the type embeddings. Supported options are `concat`.
(TODO need to support stripped_type_embedding option)
The input mode of the type embedding. Supported modes are ["concat", "strip"].
- "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network.
- "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output.
resnet_dt: bool
Time-step `dt` in the resnet construction:
y = x + dt * \phi (Wx + b)
Expand Down Expand Up @@ -172,9 +173,6 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module):

Limitations
-----------
The currently implementation does not support the following features
1. tebd_input_mode != 'concat'

The currently implementation will not support the following deprecated features
1. spin is not None
2. attn_mask == True
Expand All @@ -196,8 +194,7 @@ def __init__(
axis_neuron: int = 16,
tebd_dim: int = 8,
tebd_input_mode: str = "concat",
# set_davg_zero: bool = False,
set_davg_zero: bool = True, # TODO
set_davg_zero: bool = True,
iProzd marked this conversation as resolved.
Show resolved Hide resolved
attn: int = 128,
attn_layer: int = 2,
attn_dotr: bool = True,
Expand All @@ -217,24 +214,18 @@ def __init__(
smooth_type_embedding: bool = True,
type_one_side: bool = False,
# not implemented
stripped_type_embedding: bool = False,
spin=None,
type: Optional[str] = None,
seed: Optional[int] = None,
old_impl: bool = False,
):
super().__init__()
if stripped_type_embedding:
raise NotImplementedError("stripped_type_embedding is not supported.")
if spin is not None:
raise NotImplementedError("old implementation of spin is not supported.")
if attn_mask:
raise NotImplementedError(
"old implementation of attn_mask is not supported."
)
# TODO
if tebd_input_mode != "concat":
raise NotImplementedError("tebd_input_mode != 'concat' not implemented")
# to keep consistent with default value in this backends
if ln_eps is None:
ln_eps = 1e-5
Expand Down Expand Up @@ -377,7 +368,7 @@ def set_stat_mean_and_stddev(

def serialize(self) -> dict:
obj = self.se_atten
return {
data = {
"@class": "Descriptor",
"type": "dpa1",
"@version": 1,
Expand Down Expand Up @@ -420,6 +411,9 @@ def serialize(self) -> dict:
"trainable": True,
"spin": None,
}
if obj.tebd_input_mode in ["strip"]:
data.update({"embeddings_strip": obj.filter_layers_strip.serialize()})
return data

@classmethod
def deserialize(cls, data: dict) -> "DescrptDPA1":
Expand All @@ -432,6 +426,11 @@ def deserialize(cls, data: dict) -> "DescrptDPA1":
type_embedding = data.pop("type_embedding")
attention_layers = data.pop("attention_layers")
env_mat = data.pop("env_mat")
tebd_input_mode = data["tebd_input_mode"]
if tebd_input_mode in ["strip"]:
embeddings_strip = data.pop("embeddings_strip")
else:
embeddings_strip = None
obj = cls(**data)

def t_cvt(xx):
Expand All @@ -443,6 +442,10 @@ def t_cvt(xx):
obj.se_atten["davg"] = t_cvt(variables["davg"])
obj.se_atten["dstd"] = t_cvt(variables["dstd"])
obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings)
if tebd_input_mode in ["strip"]:
obj.se_atten.filter_layers_strip = NetworkCollection.deserialize(
embeddings_strip
)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(
attention_layers
)
Expand Down
52 changes: 42 additions & 10 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ def __init__(
tebd_dim : int
Dimension of the type embedding
tebd_input_mode : str
The way to mix the type embeddings. Supported options are `concat`.
(TODO need to support stripped_type_embedding option)
The input mode of the type embedding. Supported modes are ["concat", "strip"].
- "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network.
- "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output.
resnet_dt : bool
Time-step `dt` in the resnet construction:
y = x + dt * \phi (Wx + b)
Expand Down Expand Up @@ -191,6 +192,9 @@ def __init__(
# order matters, placed after the assignment of self.ntypes
self.reinit_exclude(exclude_types)
if self.old_impl:
assert self.tebd_input_mode in [
"concat"
], "Old implementation does not support tebd_input_mode != 'concat'."
self.dpa1_attention = NeighborWiseAttention(
self.attn_layer,
self.nnei,
Expand Down Expand Up @@ -230,16 +234,15 @@ def __init__(
)
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2
if self.tebd_input_mode in ["concat"]:
if not self.type_one_side:
self.embd_input_dim = 1 + self.tebd_dim * 2
else:
self.embd_input_dim = 1 + self.tebd_dim
self.embd_input_dim = 1 + self.tebd_dim_input
else:
self.embd_input_dim = 1

self.filter_layers_old = None
self.filter_layers = None
self.filter_layers_strip = None
if self.old_impl:
filter_layers = []
one = TypeFilter(
Expand All @@ -265,6 +268,18 @@ def __init__(
resnet_dt=self.resnet_dt,
)
self.filter_layers = filter_layers
if self.tebd_input_mode in ["strip"]:
filter_layers_strip = NetworkCollection(
ndim=0, ntypes=self.ntypes, network_type="embedding_network"
)
filter_layers_strip[0] = EmbeddingNet(
self.tebd_dim_input,
self.filter_neuron,
activation_function=self.activation_function,
precision=self.precision,
resnet_dt=self.resnet_dt,
)
self.filter_layers_strip = filter_layers_strip
iProzd marked this conversation as resolved.
Show resolved Hide resolved
self.stats = None

def get_rcut(self) -> float:
Expand Down Expand Up @@ -498,19 +513,36 @@ def forward(
rr = dmatrix
rr = rr * exclude_mask[:, :, None]
ss = rr[:, :, :1]
nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim)
atype_tebd = atype_tebd_nnei.reshape(nfnl, nnei, self.tebd_dim)
if self.tebd_input_mode in ["concat"]:
nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim)
atype_tebd = atype_tebd_nnei.reshape(nfnl, nnei, self.tebd_dim)
if not self.type_one_side:
# nfnl x nnei x (1 + tebd_dim * 2)
ss = torch.concat([ss, nlist_tebd, atype_tebd], dim=2)
else:
# nfnl x nnei x (1 + tebd_dim)
ss = torch.concat([ss, nlist_tebd], dim=2)
# nfnl x nnei x ng
gg = self.filter_layers.networks[0](ss)
elif self.tebd_input_mode in ["strip"]:
# nfnl x nnei x ng
gg_s = self.filter_layers.networks[0](ss)
assert self.filter_layers_strip is not None
if not self.type_one_side:
# nfnl x nnei x (tebd_dim * 2)
tt = torch.concat([nlist_tebd, atype_tebd], dim=2)
else:
# nfnl x nnei x tebd_dim
tt = nlist_tebd
# nfnl x nnei x ng
gg_t = self.filter_layers_strip.networks[0](tt)
if self.smooth:
gg_t = gg_t * sw.reshape(-1, self.nnei, 1)
# nfnl x nnei x ng
gg = gg_s * gg_t + gg_s
else:
raise NotImplementedError
# nfnl x nnei x ng
gg = self.filter_layers._networks[0](ss)

input_r = torch.nn.functional.normalize(
dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1
)
Expand Down
Loading