Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support stripped type embedding in DPA1 of PT/DP #3712

Merged
merged 18 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 75 additions & 17 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@
tebd_dim: int
Dimension of the type embedding
tebd_input_mode: str
The way to mix the type embeddings. Supported options are `concat`.
(TODO need to support stripped_type_embedding option)
The input mode of the type embedding. Supported modes are ["concat", "strip"].
- "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network.
- "strip": Use a separate embedding network for the type embedding and combine the output with the radial embedding network output.
iProzd marked this conversation as resolved.
Show resolved Hide resolved
resnet_dt: bool
Time-step `dt` in the resnet construction:
y = x + dt * \phi (Wx + b)
Expand Down Expand Up @@ -182,16 +183,19 @@
Whether to use smooth process in attention weights calculation.
concat_output_tebd: bool
Whether to concat type embedding at the output of the descriptor.
stripped_type_embedding: bool, Optional
(Deprecated, kept only for compatibility.)
Whether to strip the type embedding into a separate embedding network.
Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'.
Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'.
The default value is `None`, which means the `tebd_input_mode` setting will be used instead.
iProzd marked this conversation as resolved.
Show resolved Hide resolved
spin
(Only support None to keep consistent with other backend references.)
(Not used in this version. Not-none option is not implemented.)
The old implementation of deepspin.

Limitations
-----------
The current implementation does not support the following features
1. tebd_input_mode != 'concat'

The current implementation will not support the following deprecated features
1. spin is not None
2. attn_mask == True
Expand Down Expand Up @@ -233,19 +237,21 @@
smooth_type_embedding: bool = True,
concat_output_tebd: bool = True,
spin: Optional[Any] = None,
stripped_type_embedding: Optional[bool] = None,
# consistent with argcheck, not used though
seed: Optional[int] = None,
) -> None:
## seed, uniform_seed, multi_task, not included.
# Ensure compatibility with the deprecated stripped_type_embedding option.
if stripped_type_embedding is not None:
# Use the user-set stripped_type_embedding parameter first
tebd_input_mode = "strip" if stripped_type_embedding else "concat"

Check warning on line 248 in deepmd/dpmodel/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa1.py#L248

Added line #L248 was not covered by tests
iProzd marked this conversation as resolved.
Show resolved Hide resolved
if spin is not None:
raise NotImplementedError("old implementation of spin is not supported.")
if attn_mask:
raise NotImplementedError(
"old implementation of attn_mask is not supported."
)
# TODO
if tebd_input_mode != "concat":
raise NotImplementedError("tebd_input_mode != 'concat' not implemented")
# to keep consistent with the default value in this backend
if ln_eps is None:
ln_eps = 1e-5
Expand Down Expand Up @@ -290,25 +296,38 @@
activation_function="Linear",
precision=precision,
)
self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2
iProzd marked this conversation as resolved.
Show resolved Hide resolved
if self.tebd_input_mode in ["concat"]:
if not self.type_one_side:
in_dim = 1 + self.tebd_dim * 2
else:
in_dim = 1 + self.tebd_dim
self.embd_input_dim = 1 + self.tebd_dim_input
iProzd marked this conversation as resolved.
Show resolved Hide resolved
else:
in_dim = 1
self.embd_input_dim = 1
self.embeddings = NetworkCollection(
ndim=0,
ntypes=self.ntypes,
network_type="embedding_network",
)
self.embeddings[0] = EmbeddingNet(
in_dim,
self.embd_input_dim,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
)
if self.tebd_input_mode in ["strip"]:
self.embeddings_strip = NetworkCollection(
ndim=0,
ntypes=self.ntypes,
network_type="embedding_network",
)
self.embeddings_strip[0] = EmbeddingNet(
self.tebd_dim_input,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
)
else:
self.embeddings_strip = None
self.dpa1_attention = NeighborGatedAttention(
self.attn_layer,
self.nnei,
Expand Down Expand Up @@ -410,6 +429,18 @@
gg = self.embeddings[embedding_idx].call(ss)
return gg

def cal_g_strip(
    self,
    ss,
    embedding_idx,
):
    """Run the strip-mode type-embedding network on the input ``ss``.

    Flattens the trailing feature dimensions of ``ss`` to shape
    (nfnl, nnei, -1) and feeds it through the embedding network selected
    by ``embedding_idx``; returns the nfnl x nnei x ng output.
    """
    # Only valid when tebd_input_mode == "strip" set up the strip network.
    assert self.embeddings_strip is not None
    n_frames_nloc, n_neighbors = ss.shape[:2]
    flat_input = ss.reshape(n_frames_nloc, n_neighbors, -1)
    # nfnl x nnei x ng
    return self.embeddings_strip[embedding_idx].call(flat_input)

def reinit_exclude(
self,
exclude_types: List[Tuple[int, int]] = [],
Expand Down Expand Up @@ -500,11 +531,28 @@
else:
# nfnl x nnei x (1 + tebd_dim)
ss = np.concatenate([ss, atype_embd_nlist], axis=-1)
# calculate gg
# nfnl x nnei x ng
gg = self.cal_g(ss, 0)
elif self.tebd_input_mode in ["strip"]:
# nfnl x nnei x ng
gg_s = self.cal_g(ss, 0)
assert self.embeddings_strip is not None
if not self.type_one_side:
# nfnl x nnei x (tebd_dim * 2)
tt = np.concatenate([atype_embd_nlist, atype_embd_nnei], axis=-1)
else:
# nfnl x nnei x tebd_dim
tt = atype_embd_nlist
# nfnl x nnei x ng
gg_t = self.cal_g_strip(tt, 0)
if self.smooth:
gg_t = gg_t * sw.reshape(-1, self.nnei, 1)
# nfnl x nnei x ng
gg = gg_s * gg_t + gg_s
else:
raise NotImplementedError

# calculate gg
gg = self.cal_g(ss, 0)
input_r = dmatrix.reshape(-1, nnei, 4)[:, :, 1:4] / np.maximum(
np.linalg.norm(
dmatrix.reshape(-1, nnei, 4)[:, :, 1:4], axis=-1, keepdims=True
Expand Down Expand Up @@ -532,7 +580,7 @@

def serialize(self) -> dict:
"""Serialize the descriptor to dict."""
return {
data = {
"@class": "Descriptor",
"type": "dpa1",
"@version": 1,
Expand Down Expand Up @@ -575,6 +623,9 @@
"trainable": True,
"spin": None,
}
if self.tebd_input_mode in ["strip"]:
data.update({"embeddings_strip": self.embeddings_strip.serialize()})
return data

@classmethod
def deserialize(cls, data: dict) -> "DescrptDPA1":
Expand All @@ -588,11 +639,18 @@
type_embedding = data.pop("type_embedding")
attention_layers = data.pop("attention_layers")
env_mat = data.pop("env_mat")
tebd_input_mode = data["tebd_input_mode"]
if tebd_input_mode in ["strip"]:
embeddings_strip = data.pop("embeddings_strip")
else:
embeddings_strip = None
obj = cls(**data)

obj["davg"] = variables["davg"]
obj["dstd"] = variables["dstd"]
obj.embeddings = NetworkCollection.deserialize(embeddings)
if tebd_input_mode in ["strip"]:
obj.embeddings_strip = NetworkCollection.deserialize(embeddings_strip)
obj.type_embedding = TypeEmbedNet.deserialize(type_embedding)
obj.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers)
return obj
Expand Down
42 changes: 28 additions & 14 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@
tebd_dim: int
Dimension of the type embedding
tebd_input_mode: str
The way to mix the type embeddings. Supported options are `concat`.
(TODO need to support stripped_type_embedding option)
The input mode of the type embedding. Supported modes are ["concat", "strip"].
- "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network.
- "strip": Use a separate embedding network for the type embedding and combine the output with the radial embedding network output.
resnet_dt: bool
Time-step `dt` in the resnet construction:
y = x + dt * \phi (Wx + b)
Expand Down Expand Up @@ -165,16 +166,19 @@
Whether to use smooth process in attention weights calculation.
concat_output_tebd: bool
Whether to concat type embedding at the output of the descriptor.
stripped_type_embedding: bool, Optional
(Deprecated, kept only for compatibility.)
Whether to strip the type embedding into a separate embedding network.
Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'.
Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'.
The default value is `None`, which means the `tebd_input_mode` setting will be used instead.
spin
(Only support None to keep consistent with other backend references.)
(Not used in this version. Not-none option is not implemented.)
The old implementation of deepspin.

Limitations
-----------
The current implementation does not support the following features
1. tebd_input_mode != 'concat'

The current implementation will not support the following deprecated features
1. spin is not None
2. attn_mask == True
Expand All @@ -196,8 +200,7 @@
axis_neuron: int = 16,
tebd_dim: int = 8,
tebd_input_mode: str = "concat",
# set_davg_zero: bool = False,
set_davg_zero: bool = True, # TODO
set_davg_zero: bool = True,
iProzd marked this conversation as resolved.
Show resolved Hide resolved
attn: int = 128,
attn_layer: int = 2,
attn_dotr: bool = True,
Expand All @@ -216,25 +219,24 @@
ln_eps: Optional[float] = 1e-5,
smooth_type_embedding: bool = True,
type_one_side: bool = False,
stripped_type_embedding: Optional[bool] = None,
# not implemented
stripped_type_embedding: bool = False,
spin=None,
type: Optional[str] = None,
seed: Optional[int] = None,
old_impl: bool = False,
):
super().__init__()
if stripped_type_embedding:
raise NotImplementedError("stripped_type_embedding is not supported.")
# Ensure compatibility with the deprecated stripped_type_embedding option.
if stripped_type_embedding is not None:
# Use the user-set stripped_type_embedding parameter first
tebd_input_mode = "strip" if stripped_type_embedding else "concat"

Check warning on line 233 in deepmd/pt/model/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa1.py#L233

Added line #L233 was not covered by tests
if spin is not None:
raise NotImplementedError("old implementation of spin is not supported.")
if attn_mask:
raise NotImplementedError(
"old implementation of attn_mask is not supported."
)
# TODO
if tebd_input_mode != "concat":
raise NotImplementedError("tebd_input_mode != 'concat' not implemented")
# to keep consistent with the default value in this backend
if ln_eps is None:
ln_eps = 1e-5
Expand Down Expand Up @@ -377,7 +379,7 @@

def serialize(self) -> dict:
obj = self.se_atten
return {
data = {
"@class": "Descriptor",
"type": "dpa1",
"@version": 1,
Expand Down Expand Up @@ -420,6 +422,9 @@
"trainable": True,
"spin": None,
}
if obj.tebd_input_mode in ["strip"]:
data.update({"embeddings_strip": obj.filter_layers_strip.serialize()})
return data

@classmethod
def deserialize(cls, data: dict) -> "DescrptDPA1":
Expand All @@ -432,6 +437,11 @@
type_embedding = data.pop("type_embedding")
attention_layers = data.pop("attention_layers")
env_mat = data.pop("env_mat")
tebd_input_mode = data["tebd_input_mode"]
if tebd_input_mode in ["strip"]:
embeddings_strip = data.pop("embeddings_strip")
else:
embeddings_strip = None
obj = cls(**data)

def t_cvt(xx):
Expand All @@ -443,6 +453,10 @@
obj.se_atten["davg"] = t_cvt(variables["davg"])
obj.se_atten["dstd"] = t_cvt(variables["dstd"])
obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings)
if tebd_input_mode in ["strip"]:
obj.se_atten.filter_layers_strip = NetworkCollection.deserialize(
embeddings_strip
)
iProzd marked this conversation as resolved.
Show resolved Hide resolved
obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(
attention_layers
)
Expand Down
Loading