
feat (tf/pt): add atomic weights to tensor loss #4466

Merged · 6 commits · Dec 18, 2024
Changes from 4 commits
22 changes: 22 additions & 0 deletions deepmd/pt/loss/tensor.py
@@ -22,6 +22,7 @@
pref_atomic: float = 0.0,
pref: float = 0.0,
inference=False,
enable_atomic_weight: bool = False,
**kwargs,
) -> None:
r"""Construct a loss for local and global tensors.
@@ -40,6 +41,8 @@
The prefactor of the weight of global loss. It should be larger than or equal to 0.
inference : bool
If true, it will output all losses found in output, ignoring the pre-factors.
enable_atomic_weight : bool
If true, atomic weight will be used in the loss calculation.
**kwargs
Other keyword arguments.
"""
@@ -50,6 +53,7 @@
self.local_weight = pref_atomic
self.global_weight = pref
self.inference = inference
self.enable_atomic_weight = enable_atomic_weight

assert (
self.local_weight >= 0.0 and self.global_weight >= 0.0
@@ -85,6 +89,12 @@
"""
model_pred = model(**input_dict)
del learning_rate, mae

if self.enable_atomic_weight:
atomic_weight = label["atom_weight"].reshape([-1, 1])
else:
atomic_weight = 1.0

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
if (
@@ -103,6 +113,7 @@
diff = (local_tensor_pred - local_tensor_label).reshape(
[-1, self.tensor_size]
)
diff = diff * atomic_weight
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss = torch.mean(torch.square(diff))
@@ -171,4 +182,15 @@
high_prec=False,
)
)
if self.enable_atomic_weight:
    label_requirement.append(
        DataRequirementItem(
            "atom_weight",
            ndof=1,
            atomic=True,
            must=False,
            high_prec=False,
            default=1.0,
        )
    )
return label_requirement
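As a reading aid, here is a minimal standalone sketch of the weighting applied in the PyTorch branch above. The shapes and random inputs are illustrative assumptions, not values from the PR; the point is that the per-atom weight of shape `[-1, 1]` broadcasts across the tensor components of each residual before the mean of squares is taken.

```python
# Sketch of the per-atom weighting used in the local tensor loss above.
# nframes, natoms, tensor_size and the random tensors are assumptions
# for illustration only.
import torch

nframes, natoms, tensor_size = 2, 4, 3
diff = torch.randn(nframes * natoms, tensor_size)  # prediction - label, one row per atom
atomic_weight = torch.rand(nframes, natoms).reshape([-1, 1])  # one weight per atom

weighted = diff * atomic_weight  # (N, 1) broadcasts over the tensor components
l2_local_loss = torch.mean(torch.square(weighted))
print(float(l2_local_loss))
```

One consequence of this formulation: since the weight multiplies the residual before squaring, an atom with weight w contributes w² times its squared error to the mean, which is worth keeping in mind when choosing weight magnitudes.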
24 changes: 23 additions & 1 deletion deepmd/tf/loss/tensor.py
@@ -40,6 +40,7 @@
# YWolfeee: modify, use pref / pref_atomic, instead of pref_weight / pref_atomic_weight
self.local_weight = jdata.get("pref_atomic", None)
self.global_weight = jdata.get("pref", None)
self.enable_atomic_weight = jdata.get("enable_atomic_weight", False)

assert (
self.local_weight is not None and self.global_weight is not None
@@ -66,9 +67,18 @@
"global_loss": global_cvt_2_tf_float(0.0),
}

if self.enable_atomic_weight:
atomic_weight = tf.reshape(label_dict["atom_weight"], [-1, 1])
else:
atomic_weight = global_cvt_2_tf_float(1.0)

if self.local_weight > 0.0:
diff = tf.reshape(polar, [-1, self.tensor_size]) - tf.reshape(
atomic_polar_hat, [-1, self.tensor_size]
)
diff = diff * atomic_weight
local_loss = global_cvt_2_tf_float(find_atomic) * tf.reduce_mean(
-   tf.square(self.scale * (polar - atomic_polar_hat)), name="l2_" + suffix
+   tf.square(self.scale * diff), name="l2_" + suffix
)
more_loss["local_loss"] = self.display_if_exist(local_loss, find_atomic)
l2_loss += self.local_weight * local_loss
@@ -163,4 +173,16 @@
type_sel=self.type_sel,
)
)
if self.enable_atomic_weight:
    data_requirements.append(
        DataRequirementItem(
            "atom_weight",
            1,
            atomic=True,
            must=False,
            high_prec=False,
            default=1.0,
            type_sel=self.type_sel,
        )
    )
return data_requirements
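Both backends register `atom_weight` as an optional per-atom label with one component and a default of 1.0. A hedged sketch of producing such a label file, assuming the usual DeePMD-kit convention that per-atom quantities live in each set directory as `<name>.npy` with shape (nframes, natoms × ndof):

```python
# Sketch of writing an atom_weight label; the set.000 path and the
# (nframes, natoms) shape follow the standard DeePMD-kit data layout,
# which this PR does not spell out (an assumption).
import os

import numpy as np

nframes, natoms = 10, 64
weights = np.ones((nframes, natoms))
weights[:, :8] = 4.0  # e.g. up-weight the first 8 atoms of every frame

os.makedirs("set.000", exist_ok=True)
np.save("set.000/atom_weight.npy", weights)
```

Because the requirement is declared with `must=False` and `default=1.0`, systems without this file fall back to uniform weights.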
12 changes: 10 additions & 2 deletions deepmd/utils/argcheck.py
@@ -2511,8 +2511,9 @@ def loss_property():
def loss_tensor():
# doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If only `pref` is provided or both are not provided, training will be global mode, i.e. the shape of 'polarizability.npy` or `dipole.npy` should be #frams x [9 or 3]."
# doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If only `pref_atomic` is provided, training will be atomic mode, i.e. the shape of `polarizability.npy` or `dipole.npy` should be #frames x ([9 or 3] x #selected atoms). If both `pref` and `pref_atomic` are provided, training will be combined mode, and atomic label should be provided as well."
doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included."
doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #selected atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0."
doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. It controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included."
doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. It controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0."
doc_enable_atomic_weight = "If true, the atomic loss will be reweighted."
return [
Argument(
"pref", [float, int], optional=False, default=None, doc=doc_global_weight
@@ -2524,6 +2525,13 @@ def loss_tensor():
default=None,
doc=doc_local_weight,
),
Argument(
"enable_atomic_weight",
bool,
optional=True,
default=False,
doc=doc_enable_atomic_weight,
),
]
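Putting the new argument together with the existing prefactors, a training input that enables the weighting could contain a loss section like the sketch below. The `"type": "tensor"` key and the particular prefactor values are illustrative assumptions following the standard DeePMD-kit JSON input, not part of this diff:

```json
"loss": {
    "type": "tensor",
    "pref": 0.0,
    "pref_atomic": 1.0,
    "enable_atomic_weight": true
}
```

Per the documentation strings above, both `pref` and `pref_atomic` should be provided, and either can be set to 0.0.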

