feat(pt): support use_aparam_as_mask for pt backend #4246

Merged: 13 commits, Oct 26, 2024
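This PR wires the `use_aparam_as_mask` option through the dpmodel, PyTorch, and TensorFlow fitting nets: when the flag is set, atomic parameters (`aparam`) are still read from the data but are excluded from the fitting-net input, so the statistics buffers (`aparam_avg`, `aparam_inv_std`) are skipped and the network input dimension drops by `numb_aparam`. A minimal sketch of the shared dimension logic (a standalone helper written for illustration, not code taken from the PR):

```python
def fitting_input_dim(
    dim_descrpt: int,
    numb_fparam: int,
    numb_aparam: int,
    use_aparam_as_mask: bool,
) -> int:
    """Illustrative helper mirroring the in_dim logic added in this PR."""
    in_dim = dim_descrpt + numb_fparam
    if not use_aparam_as_mask:
        # aparam is concatenated to the descriptor only when it is not a mask
        in_dim += numb_aparam
    return in_dim


# Example: with the flag on, numb_aparam no longer contributes to the input.
assert fitting_input_dim(100, 2, 1, use_aparam_as_mask=False) == 103
assert fitting_input_dim(100, 2, 1, use_aparam_as_mask=True) == 102
```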
8 changes: 5 additions & 3 deletions deepmd/dpmodel/fitting/general_fitting.py
@@ -155,13 +155,15 @@ def __init__(
self.fparam_inv_std = np.ones(self.numb_fparam) # pylint: disable=no-explicit-dtype
else:
self.fparam_avg, self.fparam_inv_std = None, None
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
self.aparam_avg = np.zeros(self.numb_aparam) # pylint: disable=no-explicit-dtype
self.aparam_inv_std = np.ones(self.numb_aparam) # pylint: disable=no-explicit-dtype
else:
self.aparam_avg, self.aparam_inv_std = None, None
# init networks
in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam
in_dim = self.dim_descrpt + self.numb_fparam
if not self.use_aparam_as_mask:
in_dim += self.numb_aparam
self.nets = NetworkCollection(
1 if not self.mixed_types else 0,
self.ntypes,
@@ -389,7 +391,7 @@ def _call_common(
axis=-1,
)
# check aparam dim, concate to input descriptor
if self.numb_aparam > 0:
if not self.use_aparam_as_mask and self.numb_aparam > 0:
assert aparam is not None, "aparam should not be None"
if aparam.shape[-1] != self.numb_aparam:
raise ValueError(
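The corresponding forward-path change in `_call_common` guards the aparam branch with the same condition: standardization and concatenation only happen when the flag is off. A rough NumPy sketch of that behavior (standalone illustration, not the actual `_call_common` body):

```python
import numpy as np

def concat_fitting_input(descriptor, aparam, aparam_avg, aparam_inv_std,
                         numb_aparam, use_aparam_as_mask):
    """Append standardized aparam to the descriptor unless it is used as a mask."""
    xx = descriptor
    if not use_aparam_as_mask and numb_aparam > 0:
        assert aparam is not None, "aparam should not be None"
        if aparam.shape[-1] != numb_aparam:
            raise ValueError("aparam has the wrong last dimension")
        aparam = (aparam - aparam_avg) * aparam_inv_std
        xx = np.concatenate([xx, aparam], axis=-1)
    return xx

# With use_aparam_as_mask=True the descriptor passes through unchanged.
desc = np.ones((1, 4, 8))
apar = np.zeros((1, 4, 1))
out = concat_fitting_input(desc, apar, 0.0, 1.0, 1, use_aparam_as_mask=True)
assert out.shape == (1, 4, 8)
```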
4 changes: 0 additions & 4 deletions deepmd/dpmodel/fitting/invar_fitting.py
@@ -139,10 +139,6 @@ def __init__(
raise NotImplementedError("tot_ener_zero is not implemented")
if spin is not None:
raise NotImplementedError("spin is not implemented")
if use_aparam_as_mask:
raise NotImplementedError("use_aparam_as_mask is not implemented")
if use_aparam_as_mask:
raise NotImplementedError("use_aparam_as_mask is not implemented")
if layer_name is not None:
raise NotImplementedError("layer_name is not implemented")

15 changes: 10 additions & 5 deletions deepmd/pt/model/task/fitting.py
@@ -126,6 +126,8 @@ class GeneralFitting(Fitting):
length as `ntypes` signaling if or not removing the vaccum contribution for the atom types in the list.
type_map: list[str], Optional
A list of strings. Give the name to each type of atoms.
use_aparam_as_mask: bool
If True, the aparam will not be used in fitting net for embedding.
"""

def __init__(
@@ -147,6 +149,7 @@ def __init__(
trainable: Union[bool, list[bool]] = True,
remove_vaccum_contribution: Optional[list[bool]] = None,
type_map: Optional[list[str]] = None,
use_aparam_as_mask: bool = False,
**kwargs,
):
super().__init__()
@@ -164,6 +167,7 @@ def __init__(
self.rcond = rcond
self.seed = seed
self.type_map = type_map
self.use_aparam_as_mask = use_aparam_as_mask
# order matters, should be place after the assignment of ntypes
self.reinit_exclude(exclude_types)
self.trainable = trainable
@@ -194,7 +198,7 @@ def __init__(
)
else:
self.fparam_avg, self.fparam_inv_std = None, None
if self.numb_aparam > 0:
if not self.use_aparam_as_mask and self.numb_aparam > 0:
self.register_buffer(
"aparam_avg",
torch.zeros(self.numb_aparam, dtype=self.prec, device=device),
@@ -206,7 +210,9 @@ def __init__(
else:
self.aparam_avg, self.aparam_inv_std = None, None

in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam
in_dim = self.dim_descrpt + self.numb_fparam
if not self.use_aparam_as_mask:
in_dim += self.numb_aparam

self.filter_layers = NetworkCollection(
1 if not self.mixed_types else 0,
@@ -291,13 +297,12 @@ def serialize(self) -> dict:
# "trainable": self.trainable ,
# "atom_ener": self.atom_ener ,
# "layer_name": self.layer_name ,
# "use_aparam_as_mask": self.use_aparam_as_mask ,
# "spin": self.spin ,
## NOTICE: not supported by far
"tot_ener_zero": False,
"trainable": [self.trainable] * (len(self.neuron) + 1),
"layer_name": None,
"use_aparam_as_mask": False,
"use_aparam_as_mask": self.use_aparam_as_mask,
"spin": None,
}

@@ -439,7 +444,7 @@ def _forward_common(
dim=-1,
)
# check aparam dim, concate to input descriptor
if self.numb_aparam > 0:
if not self.use_aparam_as_mask and self.numb_aparam > 0:
assert aparam is not None, "aparam should not be None"
assert self.aparam_avg is not None
assert self.aparam_inv_std is not None
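On the PyTorch side the serialized dict now records the real flag instead of a hard-coded `False`, so conversions between the dpmodel, PT, and TF backends preserve it. For reference, a hypothetical fitting section enabling the option could look as follows (written as a Python dict mirroring the JSON input script; the key names follow the constructor arguments in this diff and are an assumption, not taken from the documentation):

```python
# Hypothetical fitting_net configuration; keys mirror the constructor
# arguments touched in this PR and are assumptions, not verified docs.
fitting_net = {
    "neuron": [240, 240, 240],
    "numb_fparam": 0,
    "numb_aparam": 1,            # aparam is still provided with the data
    "use_aparam_as_mask": True,  # but is kept out of the fitting-net input
    "resnet_dt": True,
    "seed": 1,
}
```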
5 changes: 4 additions & 1 deletion deepmd/pt/model/task/invar_fitting.py
@@ -77,7 +77,8 @@ class InvarFitting(GeneralFitting):
The `set_davg_zero` key in the descrptor should be set.
type_map: list[str], Optional
A list of strings. Give the name to each type of atoms.

use_aparam_as_mask: bool
If True, the aparam will not be used in fitting net for embedding.
"""

def __init__(
@@ -99,6 +100,7 @@ def __init__(
exclude_types: list[int] = [],
atom_ener: Optional[list[Optional[torch.Tensor]]] = None,
type_map: Optional[list[str]] = None,
use_aparam_as_mask: bool = False,
**kwargs,
):
self.dim_out = dim_out
@@ -122,6 +124,7 @@
if atom_ener is None or len([x for x in atom_ener if x is not None]) == 0
else [x is not None for x in atom_ener],
type_map=type_map,
use_aparam_as_mask=use_aparam_as_mask,
**kwargs,
)

28 changes: 15 additions & 13 deletions deepmd/tf/fit/ener.py
@@ -340,7 +340,7 @@
self.fparam_std[ii] = protection
self.fparam_inv_std = 1.0 / self.fparam_std
# stat aparam
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
sys_sumv = []
sys_sumv2 = []
sys_sumn = []
@@ -384,7 +384,7 @@
ext_fparam = tf.reshape(ext_fparam, [-1, self.numb_fparam])
ext_fparam = tf.cast(ext_fparam, self.fitting_precision)
layer = tf.concat([layer, ext_fparam], axis=1)
if aparam is not None:
if aparam is not None and not self.use_aparam_as_mask:
ext_aparam = tf.slice(
aparam,
[0, start_index * self.numb_aparam],
@@ -505,7 +505,7 @@
self.fparam_avg = 0.0
if self.fparam_inv_std is None:
self.fparam_inv_std = 1.0
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
if self.aparam_avg is None:
self.aparam_avg = 0.0
if self.aparam_inv_std is None:
@@ -561,7 +561,7 @@
trainable=False,
initializer=tf.constant_initializer(self.fparam_inv_std),
)
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
t_aparam_avg = tf.get_variable(
"t_aparam_avg",
self.numb_aparam,
@@ -602,12 +602,11 @@
fparam = (fparam - t_fparam_avg) * t_fparam_istd

aparam = None
if not self.use_aparam_as_mask:
if self.numb_aparam > 0:
aparam = input_dict["aparam"]
aparam = tf.reshape(aparam, [-1, self.numb_aparam])
aparam = (aparam - t_aparam_avg) * t_aparam_istd
aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]])
if not self.use_aparam_as_mask and self.numb_aparam > 0:
aparam = input_dict["aparam"]
aparam = tf.reshape(aparam, [-1, self.numb_aparam])
aparam = (aparam - t_aparam_avg) * t_aparam_istd
aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]])

atype_nall = tf.reshape(atype, [-1, natoms[1]])
self.atype_nloc = tf.slice(
@@ -783,7 +782,7 @@
self.fparam_inv_std = get_tensor_by_name_from_graph(
graph, f"fitting_attr{suffix}/t_fparam_istd"
)
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
self.aparam_avg = get_tensor_by_name_from_graph(
graph, f"fitting_attr{suffix}/t_aparam_avg"
)
@@ -883,7 +882,7 @@
if fitting.numb_fparam > 0:
fitting.fparam_avg = data["@variables"]["fparam_avg"]
fitting.fparam_inv_std = data["@variables"]["fparam_inv_std"]
if fitting.numb_aparam > 0:
if fitting.numb_aparam > 0 and not fitting.use_aparam_as_mask:
fitting.aparam_avg = data["@variables"]["aparam_avg"]
fitting.aparam_inv_std = data["@variables"]["aparam_inv_std"]
return fitting
@@ -896,6 +895,9 @@
dict
The serialized data
"""
in_dim = self.dim_descrpt + self.numb_fparam
if not self.use_aparam_as_mask:
in_dim += self.numb_aparam
data = {
"@class": "Fitting",
"type": "ener",
@@ -922,7 +924,7 @@
"nets": self.serialize_network(
ntypes=self.ntypes,
ndim=0 if self.mixed_types else 1,
in_dim=self.dim_descrpt + self.numb_fparam + self.numb_aparam,
in_dim=in_dim,
neuron=self.n_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.resnet_dt,
8 changes: 7 additions & 1 deletion source/tests/consistent/fitting/common.py
@@ -18,7 +18,7 @@
class FittingTest:
"""Useful utilities for descriptor tests."""

def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, suffix):
def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, aparam, suffix):
t_inputs = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_inputs")
t_natoms = tf.placeholder(tf.int32, natoms.shape, name="i_natoms")
t_atype = tf.placeholder(tf.int32, [None], name="i_atype")
@@ -30,6 +30,12 @@ def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, suffix):
)
extras["fparam"] = t_fparam
feed_dict[t_fparam] = fparam
if aparam is not None:
t_aparam = tf.placeholder(
GLOBAL_TF_FLOAT_PRECISION, [None, None], name="i_aparam"
)
extras["aparam"] = t_aparam
feed_dict[t_aparam] = aparam
t_out = obj.build(
t_inputs,
t_natoms,
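The consistency-test helper gains an optional `aparam` placeholder so the TF graph can be fed alongside the other backends. A small sketch of how a test constructs and passes the per-atom parameter (shapes follow the `test_dos.py` changes below; the concrete `atype` array here is illustrative):

```python
import numpy as np

# One aparam value per atom (numb_aparam = 1), as in the DOS consistency test.
atype = np.array([0, 1, 1, 0, 1, 1, 0])
aparam = np.zeros_like(atype, dtype=np.float64).reshape(-1, 1)

# Each backend then receives it only when numb_aparam > 0, e.g.:
#   tf : build_tf_fitting(obj, inputs, natoms, atype, fparam, aparam, suffix)
#   pt : obj(..., aparam=torch.from_numpy(aparam))
#   dp : dp_obj(inputs, atype.reshape(1, -1), aparam=aparam)
assert aparam.shape == (7, 1)
```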
22 changes: 22 additions & 0 deletions source/tests/consistent/fitting/test_dos.py
@@ -58,6 +58,7 @@
("float64", "float32"), # precision
(True, False), # mixed_types
(0, 1), # numb_fparam
(0, 1), # numb_aparam
(10, 20), # numb_dos
)
class TestDOS(CommonTest, FittingTest, unittest.TestCase):
@@ -68,13 +69,15 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return {
"neuron": [5, 5, 5],
"resnet_dt": resnet_dt,
"precision": precision,
"numb_fparam": numb_fparam,
"numb_aparam": numb_aparam,
"seed": 20240217,
"numb_dos": numb_dos,
}
@@ -86,6 +89,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,

Check notice (Code scanning / CodeQL): Unused local variable. Variable numb_aparam is not used.
numb_dos,
) = self.param
return CommonTest.skip_pt
@@ -115,6 +119,9 @@
# inconsistent if not sorted
self.atype.sort()
self.fparam = -np.ones((1,), dtype=GLOBAL_NP_FLOAT_PRECISION)
self.aparam = np.zeros_like(
self.atype, dtype=GLOBAL_NP_FLOAT_PRECISION
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
@@ -123,6 +130,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return {
@@ -137,6 +145,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return self.build_tf_fitting(
@@ -145,6 +154,7 @@
self.natoms,
self.atype,
self.fparam if numb_fparam else None,
self.aparam if numb_aparam else None,
suffix,
)

@@ -154,6 +164,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return (
@@ -163,6 +174,9 @@
fparam=torch.from_numpy(self.fparam).to(device=PT_DEVICE)
if numb_fparam
else None,
aparam=torch.from_numpy(self.aparam).to(device=PT_DEVICE)
if numb_aparam
else None,
)["dos"]
.detach()
.cpu()
@@ -175,12 +189,14 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return dp_obj(
self.inputs,
self.atype.reshape(1, -1),
fparam=self.fparam if numb_fparam else None,
aparam=self.aparam if numb_aparam else None,
)["dos"]

def eval_jax(self, jax_obj: Any) -> Any:
@@ -189,13 +205,15 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return np.asarray(
jax_obj(
jnp.asarray(self.inputs),
jnp.asarray(self.atype.reshape(1, -1)),
fparam=jnp.asarray(self.fparam) if numb_fparam else None,
aparam=jnp.asarray(self.aparam) if numb_aparam else None,
)["dos"]
)

@@ -206,13 +224,15 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return np.asarray(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None,
aparam=array_api_strict.asarray(self.aparam) if numb_aparam else None,
)["dos"]
)

@@ -230,6 +250,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
if precision == "float64":
@@ -247,6 +268,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
if precision == "float64":