Skip to content

Commit

Permalink
move stripped_type_embedding argcheck into class
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed May 6, 2024
1 parent 9334c0a commit 3a98a65
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 36 deletions.
9 changes: 9 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ class DescrptDPA1(NativeOP, BaseDescriptor):
Whether to use smooth process in attention weights calculation.
concat_output_tebd: bool
Whether to concat type embedding at the output of the descriptor.
stripped_type_embedding: bool, Optional
(Deprecated, kept only for compatibility.)
        Whether to strip the type embedding into a separate embedding network.
Setting this to `True` is equivalent to setting `tebd_input_mode` to 'strip'.
spin
(Only support None to keep consistent with other backend references.)
(Not used in this version. Not-none option is not implemented.)
Expand Down Expand Up @@ -231,10 +235,15 @@ def __init__(
smooth_type_embedding: bool = True,
concat_output_tebd: bool = True,
spin: Optional[Any] = None,
stripped_type_embedding: Optional[bool] = None,
# consistent with argcheck, not used though
seed: Optional[int] = None,
) -> None:
## seed, uniform_seed, multi_task, not included.
# Ensure compatibility with the deprecated stripped_type_embedding option.
if stripped_type_embedding is not None:
# Use the user-set stripped_type_embedding parameter first
tebd_input_mode = "strip" if stripped_type_embedding else "concat"

Check warning on line 246 in deepmd/dpmodel/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa1.py#L246

Added line #L246 was not covered by tests
if spin is not None:
raise NotImplementedError("old implementation of spin is not supported.")
if attn_mask:
Expand Down
9 changes: 9 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module):
Whether to use smooth process in attention weights calculation.
concat_output_tebd: bool
Whether to concat type embedding at the output of the descriptor.
stripped_type_embedding: bool, Optional
(Deprecated, kept only for compatibility.)
        Whether to strip the type embedding into a separate embedding network.
Setting this to `True` is equivalent to setting `tebd_input_mode` to 'strip'.
spin
(Only support None to keep consistent with other backend references.)
(Not used in this version. Not-none option is not implemented.)
Expand Down Expand Up @@ -213,13 +217,18 @@ def __init__(
ln_eps: Optional[float] = 1e-5,
smooth_type_embedding: bool = True,
type_one_side: bool = False,
stripped_type_embedding: Optional[bool] = None,
# not implemented
spin=None,
type: Optional[str] = None,
seed: Optional[int] = None,
old_impl: bool = False,
):
super().__init__()
# Ensure compatibility with the deprecated stripped_type_embedding option.
if stripped_type_embedding is not None:
# Use the user-set stripped_type_embedding parameter first
tebd_input_mode = "strip" if stripped_type_embedding else "concat"

Check warning on line 231 in deepmd/pt/model/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa1.py#L231

Added line #L231 was not covered by tests
if spin is not None:
raise NotImplementedError("old implementation of spin is not supported.")
if attn_mask:
Expand Down
13 changes: 11 additions & 2 deletions deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ class DescrptSeAtten(DescrptSeA):
And when using stripped type embedding, whether to dot smooth factor on the network output of type embedding
to keep the network smooth, instead of setting `set_davg_zero` to be True.
Default value will be True in `se_atten_v2` descriptor.
stripped_type_embedding: bool, Optional
(Deprecated, kept only for compatibility.)
        Whether to strip the type embedding into a separate embedding network.
Setting this to `True` is equivalent to setting `tebd_input_mode` to 'strip'.
Raises
------
Expand Down Expand Up @@ -193,10 +197,15 @@ def __init__(
ln_eps: Optional[float] = 1e-3,
concat_output_tebd: bool = True,
env_protection: float = 0.0, # not implement!!
stripped_type_embedding: Optional[bool] = None,
**kwargs,
) -> None:
# Ensure compatibility with the deprecated `stripped_type_embedding` option.
stripped_type_embedding = tebd_input_mode == "strip"
# Ensure compatibility with the deprecated stripped_type_embedding option.
if stripped_type_embedding is None:
stripped_type_embedding = tebd_input_mode == "strip"
else:
# Use the user-set stripped_type_embedding parameter first
tebd_input_mode = "strip" if stripped_type_embedding else "concat"

Check warning on line 208 in deepmd/tf/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten.py#L208

Added line #L208 was not covered by tests
if not set_davg_zero and not (
stripped_type_embedding and smooth_type_embedding
):
Expand Down
45 changes: 11 additions & 34 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,20 @@ def descrpt_se_atten_args():
f"When `type_one_side` is False, the input is `input_t = concat([tebd_j, tebd_i])`. {doc_only_pt_supported} When `type_one_side` is True, the input is `input_t = tebd_j`. "
"The output is `out_ij = embeding_t(input_t) * embeding_s(r_ij) + embeding_s(r_ij)` for the pair-wise representation of atom i with neighbor j."
)
doc_stripped_type_embedding = (
"(Deprecated, kept only for compatibility.) Whether to strip the type embedding into a separated embedding network. "
"Setting this to `True` is equivalent to setting `tebd_input_mode` to 'strip'."
)

return [
*descrpt_se_atten_common_args(),
Argument(
"stripped_type_embedding",
bool,
optional=True,
default=None,
doc=doc_stripped_type_embedding,
),
Argument(
"smooth_type_embedding",
bool,
Expand Down Expand Up @@ -2311,39 +2322,6 @@ def gen_args(**kwargs) -> List[Argument]:
]


def backend_compat(data):
    """Translate deprecated descriptor options into their modern equivalents.

    Currently handles the legacy ``stripped_type_embedding`` flag of the
    ``se_atten`` descriptor: when it is set to True, it is removed and mapped
    onto ``tebd_input_mode="strip"``.

    Parameters
    ----------
    data : dict
        The full input data dict; ``data["model"]`` is expected to exist.

    Returns
    -------
    dict
        A shallow copy of ``data`` whose descriptor parameter dict(s) no
        longer carry the deprecated flag.

    Raises
    ------
    ValueError
        If ``stripped_type_embedding`` is True while ``tebd_input_mode`` is
        explicitly set to a value other than ``"strip"``.
    """
    data = data.copy()

    def compat_stripped_type_embedding(descriptor_param):
        # stripped_type_embedding only ever existed in the old DescrptSeAtten;
        # the short-circuit `and` means the key is popped only when the
        # descriptor type is "se_atten", matching the historical behavior.
        descriptor_param = descriptor_param.copy()
        if descriptor_param.get(
            "type", "se_e2_a"
        ) == "se_atten" and descriptor_param.pop("stripped_type_embedding", False):
            if "tebd_input_mode" not in descriptor_param:
                # Deprecated flag wins only when the new option is unset.
                descriptor_param["tebd_input_mode"] = "strip"
            elif descriptor_param["tebd_input_mode"] != "strip":
                raise ValueError(
                    "Conflict detected: 'stripped_type_embedding' is set to True, but 'tebd_input_mode' is not 'strip'. Please ensure 'tebd_input_mode' is set to 'strip' when 'stripped_type_embedding' is True."
                )
        return descriptor_param

    if "descriptor" in data["model"]:
        descriptor = data["model"]["descriptor"]
        if "list" not in descriptor:
            # Single-descriptor model.
            data["model"]["descriptor"] = compat_stripped_type_embedding(descriptor)
        else:
            # Hybrid descriptor: convert each sub-descriptor in place.
            sub_descriptors = descriptor["list"]
            for ii, dd in enumerate(sub_descriptors):
                sub_descriptors[ii] = compat_stripped_type_embedding(dd)
    return data


def normalize_multi_task(data):
# single-task or multi-task mode
if data["model"].get("type", "standard") not in ("standard", "multi"):
Expand Down Expand Up @@ -2540,7 +2518,6 @@ def normalize_fitting_weight(fitting_keys, data_keys, fitting_weight=None):

def normalize(data):
data = normalize_multi_task(data)
data = backend_compat(data)

base = Argument("base", dict, gen_args())
data = base.normalize_value(data, trim_pattern="_*")
Expand Down

0 comments on commit 3a98a65

Please sign in to comment.