Skip to content

Commit

Permalink
fix: consistent DPA-1 model (#4320)
Browse files Browse the repository at this point in the history
Fix #4022.
Note that `smooth_type_embedding==True` is not consistent between the TF
backend and the other backends.
Also fix several related issues: missing f-string prefixes in fitting
error messages, and guards against the unsupported hybrid-descriptor +
type-embedding serialization.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Enhanced configurability of descriptors with new optional parameters
for type mapping and type count.
- Introduction of a new class `DescrptSeAttenV2` for advanced attention
mechanisms.
- Added a new unit test framework for validating energy models across
multiple backends.

- **Bug Fixes**
- Improved error handling in descriptor serialization methods to prevent
unsupported operations.

- **Documentation**
- Updated backend documentation to include JAX support and clarify file
extensions for various backends.

- **Style**
- Enhanced readability of error messages in fitting classes.

- **Tests**
- Comprehensive unit tests added for energy models across different
machine learning frameworks.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Nov 8, 2024
1 parent 3236db5 commit 0199ad5
Show file tree
Hide file tree
Showing 13 changed files with 389 additions and 46 deletions.
6 changes: 6 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class DescrptHybrid(BaseDescriptor, NativeOP):
def __init__(
self,
list: list[Union[BaseDescriptor, dict[str, Any]]],
type_map: Optional[list[str]] = None,
ntypes: Optional[int] = None, # to be compat with input
) -> None:
super().__init__()
# warning: list is conflict with built-in list
Expand All @@ -56,6 +58,10 @@ def __init__(
if isinstance(ii, BaseDescriptor):
formatted_descript_list.append(ii)
elif isinstance(ii, dict):
ii = ii.copy()
# only pass if not already set
ii.setdefault("type_map", type_map)
ii.setdefault("ntypes", ntypes)
formatted_descript_list.append(BaseDescriptor(**ii))
else:
raise NotImplementedError
Expand Down
14 changes: 4 additions & 10 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
)
from deepmd.dpmodel.fitting.ener_fitting import (
EnergyFittingNet,
)
Expand Down Expand Up @@ -39,16 +36,13 @@ def get_standard_model(data: dict) -> EnergyModel:
data : dict
The data to construct the model.
"""
descriptor_type = data["descriptor"].pop("type")
data["descriptor"]["type_map"] = data["type_map"]
data["descriptor"]["ntypes"] = len(data["type_map"])
fitting_type = data["fitting_net"].pop("type")
data["fitting_net"]["type_map"] = data["type_map"]
if descriptor_type == "se_e2_a":
descriptor = DescrptSeA(
**data["descriptor"],
)
else:
raise ValueError(f"Unknown descriptor type {descriptor_type}")
descriptor = BaseDescriptor(
**data["descriptor"],
)
if fitting_type == "ener":
fitting = EnergyFittingNet(
ntypes=descriptor.get_ntypes(),
Expand Down
4 changes: 4 additions & 0 deletions deepmd/jax/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from deepmd.jax.descriptor.hybrid import (
DescrptHybrid,
)
from deepmd.jax.descriptor.se_atten_v2 import (
DescrptSeAttenV2,
)
from deepmd.jax.descriptor.se_e2_a import (
DescrptSeA,
)
Expand All @@ -27,6 +30,7 @@
"DescrptSeT",
"DescrptSeTTebd",
"DescrptDPA1",
"DescrptSeAttenV2",
"DescrptDPA2",
"DescrptHybrid",
]
1 change: 1 addition & 0 deletions deepmd/jax/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def get_standard_model(data: dict):
data = deepcopy(data)
descriptor_type = data["descriptor"].pop("type")
data["descriptor"]["type_map"] = data["type_map"]
data["descriptor"]["ntypes"] = len(data["type_map"])
fitting_type = data["fitting_net"].pop("type")
data["fitting_net"]["type_map"] = data["type_map"]
descriptor = BaseDescriptor.get_class_by_type(descriptor_type)(
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,8 @@ def _forward_common(

if nd != self.dim_descrpt:
raise ValueError(
"get an input descriptor of dim {nd},"
"which is not consistent with {self.dim_descrpt}."
f"get an input descriptor of dim {nd},"
f"which is not consistent with {self.dim_descrpt}."
)
# check fparam dim, concate to input descriptor
if self.numb_fparam > 0:
Expand Down
6 changes: 6 additions & 0 deletions deepmd/tf/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,8 @@ def update_sel(
return local_jdata_cpy, min_nbor_dist

def serialize(self, suffix: str = "") -> dict:
if hasattr(self, "type_embedding"):
raise NotImplementedError("hybrid + type embedding is not supported")
return {
"@class": "Descriptor",
"type": "hybrid",
Expand All @@ -485,4 +487,8 @@ def deserialize(cls, data: dict, suffix: str = "") -> "DescrptHybrid":
for idx, ii in enumerate(data["list"])
],
)
# search for type embedding
for ii in obj.descrpt_list:
if hasattr(ii, "type_embedding"):
raise NotImplementedError("hybrid + type embedding is not supported")
return obj
123 changes: 96 additions & 27 deletions deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,11 @@ def __init__(
if scaling_factor != 1.0:
raise NotImplementedError("scaling_factor is not supported.")
if not normalize:
raise NotImplementedError("normalize is not supported.")
raise NotImplementedError("Disabling normalize is not supported.")
if temperature is not None:
raise NotImplementedError("temperature is not supported.")
if not concat_output_tebd:
raise NotImplementedError("concat_output_tebd is not supported.")
raise NotImplementedError("Disbaling concat_output_tebd is not supported.")
if env_protection != 0.0:
raise NotImplementedError("env_protection != 0.0 is not supported.")
# to keep consistent with default value in this backends
Expand Down Expand Up @@ -1866,7 +1866,11 @@ def deserialize(cls, data: dict, suffix: str = ""):
if cls is not DescrptSeAtten:
raise NotImplementedError(f"Not implemented in class {cls.__name__}")
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
if data["smooth_type_embedding"]:
raise RuntimeError(
"The implementation for smooth_type_embedding is inconsistent with other backends"
)
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
data.pop("type")
embedding_net_variables = cls.deserialize_network(
Expand All @@ -1878,10 +1882,13 @@ def deserialize(cls, data: dict, suffix: str = ""):
data.pop("env_mat")
variables = data.pop("@variables")
tebd_input_mode = data["tebd_input_mode"]
if tebd_input_mode in ["strip"]:
raise ValueError(
"Deserialization is unsupported for `tebd_input_mode='strip'` in the native model."
)
type_embedding = TypeEmbedNet.deserialize(
data.pop("type_embedding"), suffix=suffix
)
if "use_tebd_bias" not in data:
# v1 compatibility
data["use_tebd_bias"] = True
type_embedding.use_tebd_bias = data.pop("use_tebd_bias")
descriptor = cls(**data)
descriptor.embedding_net_variables = embedding_net_variables
descriptor.attention_layer_variables = attention_layer_variables
Expand All @@ -1891,6 +1898,17 @@ def deserialize(cls, data: dict, suffix: str = ""):
descriptor.dstd = variables["dstd"].reshape(
descriptor.ntypes, descriptor.ndescrpt
)
descriptor.type_embedding = type_embedding
if tebd_input_mode in ["strip"]:
type_one_side = data["type_one_side"]
two_side_embeeding_net_variables = cls.deserialize_network_strip(
data.pop("embeddings_strip"),
suffix=suffix,
type_one_side=type_one_side,
)
descriptor.two_side_embeeding_net_variables = (
two_side_embeeding_net_variables
)
return descriptor

def serialize(self, suffix: str = "") -> dict:
Expand All @@ -1906,10 +1924,9 @@ def serialize(self, suffix: str = "") -> dict:
dict
The serialized data
"""
if self.stripped_type_embedding and type(self) is DescrptSeAtten:
# only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'
raise NotImplementedError(
"serialization is unsupported by the native model when tebd_input_mode=='strip'"
if self.smooth:
raise RuntimeError(
"The implementation for smooth_type_embedding is inconsistent with other backends"
)
# todo support serialization when tebd_input_mode=='strip' and type_one_side is True
if self.stripped_type_embedding and self.type_one_side:
Expand All @@ -1927,10 +1944,18 @@ def serialize(self, suffix: str = "") -> dict:
assert self.davg is not None
assert self.dstd is not None

tebd_dim = self.type_embedding.neuron[0]
if self.tebd_input_mode in ["concat"]:
if not self.type_one_side:
embd_input_dim = 1 + tebd_dim * 2
else:
embd_input_dim = 1 + tebd_dim
else:
embd_input_dim = 1
data = {
"@class": "Descriptor",
"type": "se_atten",
"@version": 1,
"type": "dpa1",
"@version": 2,
"rcut": self.rcut_r,
"rcut_smth": self.rcut_r_smth,
"sel": self.sel_a,
Expand All @@ -1952,9 +1977,7 @@ def serialize(self, suffix: str = "") -> dict:
"embeddings": self.serialize_network(
ntypes=self.ntypes,
ndim=0,
in_dim=1
if not hasattr(self, "embd_input_dim")
else self.embd_input_dim,
in_dim=embd_input_dim,
neuron=self.filter_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.filter_resnet_dt,
Expand Down Expand Up @@ -1986,17 +2009,23 @@ def serialize(self, suffix: str = "") -> dict:
"type_one_side": self.type_one_side,
"spin": self.spin,
}
data["type_embedding"] = self.type_embedding.serialize(suffix=suffix)
data["use_tebd_bias"] = self.type_embedding.use_tebd_bias
data["tebd_dim"] = tebd_dim
if len(self.type_embedding.neuron) > 1:
raise NotImplementedError(
"Only support single layer type embedding network"
)
if self.tebd_input_mode in ["strip"]:
assert (
type(self) is not DescrptSeAtten
), "only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'"
# assert (
# type(self) is not DescrptSeAtten
# ), "only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'"
data.update(
{
"embeddings_strip": self.serialize_network_strip(
ntypes=self.ntypes,
ndim=0,
in_dim=2
* self.tebd_dim, # only DescrptDPA1Compat has this attribute
in_dim=2 * tebd_dim,
neuron=self.filter_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.filter_resnet_dt,
Expand All @@ -2006,8 +2035,54 @@ def serialize(self, suffix: str = "") -> dict:
)
}
)
# default values
data.update(
{
"scaling_factor": 1.0,
"normalize": True,
"temperature": None,
"concat_output_tebd": True,
"use_econf_tebd": False,
}
)
data["attention_layers"] = self.update_attention_layers_serialize(
data["attention_layers"]
)
return data

def update_attention_layers_serialize(self, data: dict):
    """Update the serialized attention-layer data to be consistent with other backend references.

    Parameters
    ----------
    data : dict
        The serialized attention-layer data produced by the TF backend.
        Its ``attention_layers`` entries are updated in place.

    Returns
    -------
    dict
        A new top-level dict wrapping *data* with the metadata and default
        fields expected by the reference (non-TF) implementations.
    """
    # Metadata/defaults expected by the reference format; any key already
    # present in ``data`` takes precedence over these defaults.
    merged = {
        "@class": "NeighborGatedAttention",
        "@version": 1,
        "scaling_factor": 1.0,
        "normalize": True,
        "temperature": None,
        **data,
    }
    # Per-layer fields the TF graph does not store explicitly.
    layer_fields = {
        "nnei": self.nnei_a,
        "embed_dim": self.filter_neuron[-1],
        "hidden_dim": self.att_n,
        "dotr": self.attn_dotr,
        "do_mask": self.attn_mask,
        "scaling_factor": 1.0,
        "normalize": True,
        "temperature": None,
        "precision": self.filter_precision.name,
    }
    for idx in range(self.attn_layer):
        layer = merged["attention_layers"][idx]
        layer.update(layer_fields)
        inner = layer["attention_layer"]
        inner.update(layer_fields)
        # TF se_atten always uses single-head attention.
        inner["num_heads"] = 1
    return merged


class DescrptDPA1Compat(DescrptSeAtten):
r"""Consistent version of the model for testing with other backend references.
Expand Down Expand Up @@ -2433,17 +2508,11 @@ def serialize(self, suffix: str = "") -> dict:
{
"type": "dpa1",
"@version": 2,
"tebd_dim": self.tebd_dim,
"scaling_factor": self.scaling_factor,
"normalize": self.normalize,
"temperature": self.temperature,
"concat_output_tebd": self.concat_output_tebd,
"use_econf_tebd": self.use_econf_tebd,
"use_tebd_bias": self.use_tebd_bias,
"type_embedding": self.type_embedding.serialize(suffix),
}
)
data["attention_layers"] = self.update_attention_layers_serialize(
data["attention_layers"]
)
return data
16 changes: 15 additions & 1 deletion deepmd/tf/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
Optional,
)

from deepmd.tf.utils.type_embed import (
TypeEmbedNet,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -127,10 +130,13 @@ def deserialize(cls, data: dict, suffix: str = ""):
Model
The deserialized model
"""
raise RuntimeError(
"The implementation for smooth_type_embedding is inconsistent with other backends"
)
if cls is not DescrptSeAttenV2:
raise NotImplementedError(f"Not implemented in class {cls.__name__}")
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
data.pop("type")
embedding_net_variables = cls.deserialize_network(
Expand All @@ -147,6 +153,13 @@ def deserialize(cls, data: dict, suffix: str = ""):
suffix=suffix,
type_one_side=type_one_side,
)
type_embedding = TypeEmbedNet.deserialize(
data.pop("type_embedding"), suffix=suffix
)
if "use_tebd_bias" not in data:
# v1 compatibility
data["use_tebd_bias"] = True
type_embedding.use_tebd_bias = data.pop("use_tebd_bias")
descriptor = cls(**data)
descriptor.embedding_net_variables = embedding_net_variables
descriptor.attention_layer_variables = attention_layer_variables
Expand All @@ -157,6 +170,7 @@ def deserialize(cls, data: dict, suffix: str = ""):
descriptor.dstd = variables["dstd"].reshape(
descriptor.ntypes, descriptor.ndescrpt
)
descriptor.type_embedding = type_embedding
return descriptor

def serialize(self, suffix: str = "") -> dict:
Expand Down
6 changes: 5 additions & 1 deletion deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def __init__(
len(self.layer_name) == len(self.n_neuron) + 1
), "length of layer_name should be that of n_neuron + 1"
self.mixed_types = mixed_types
self.tebd_dim = 0

def get_numb_fparam(self) -> int:
"""Get the number of frame parameters."""
Expand Down Expand Up @@ -754,6 +755,8 @@ def build(
outs = tf.reshape(outs, [-1])

tf.summary.histogram("fitting_net_output", outs)
# recover original dim_descrpt, which needs to be serialized
self.dim_descrpt = original_dim_descrpt
return tf.reshape(outs, [-1])

def init_variables(
Expand Down Expand Up @@ -908,7 +911,7 @@ def serialize(self, suffix: str = "") -> dict:
"@version": 2,
"var_name": "energy",
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
"dim_descrpt": self.dim_descrpt + self.tebd_dim,
"mixed_types": self.mixed_types,
"dim_out": 1,
"neuron": self.n_neuron,
Expand All @@ -930,6 +933,7 @@ def serialize(self, suffix: str = "") -> dict:
ndim=0 if self.mixed_types else 1,
in_dim=(
self.dim_descrpt
+ self.tebd_dim
+ self.numb_fparam
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
),
Expand Down
Loading

0 comments on commit 0199ad5

Please sign in to comment.