diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 9e69b6841e..1c81d42013 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -151,6 +151,15 @@ def get_zbl_model(model_params): ) +def _can_be_converted_to_float(value): + try: + float(value) + return True + except (TypeError, ValueError): + # return false for any failure... + return False + + def _convert_preset_out_bias_to_array(preset_out_bias, type_map): if preset_out_bias is not None: for kk in preset_out_bias: @@ -160,7 +169,17 @@ def _convert_preset_out_bias_to_array(preset_out_bias, type_map): ) for jj in range(len(preset_out_bias[kk])): if preset_out_bias[kk][jj] is not None: - preset_out_bias[kk][jj] = np.array(preset_out_bias[kk][jj]) + if isinstance(preset_out_bias[kk][jj], list): + bb = preset_out_bias[kk][jj] + elif _can_be_converted_to_float(preset_out_bias[kk][jj]): + bb = [float(preset_out_bias[kk][jj])] + else: + raise ValueError( + f"unsupported type/value of the {jj}th element of " + f"preset_out_bias['{kk}'] " + f"{type(preset_out_bias[kk][jj])}" + ) + preset_out_bias[kk][jj] = np.array(bb) return preset_out_bias diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 4eab9d87df..a799b6b0c4 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1772,7 +1772,7 @@ def model_args(exclude_hybrid=False): doc_spin = "The settings for systems with spin." doc_atom_exclude_types = "Exclude the atomic contribution of the listed atom types" doc_pair_exclude_types = "The atom pairs of the listed types are not treated to be neighbors, i.e. they do not see each other." - doc_preset_out_bias = "The preset bias of the atomic output. Is provided as a dict. Taking the energy model that has three atom types for example, the preset_out_bias may be given as `{ 'energy': [null, 0., 1.] }`. In this case the bias of type 1 and 2 are set to 0. and 1., respectively.The set_davg_zero should be set to true." + doc_preset_out_bias = "The preset bias of the atomic output. Note that the set_davg_zero should be set to true. The bias is provided as a dict. Taking the energy model that has three atom types for example, the `preset_out_bias` may be given as `{ 'energy': [null, 0., 1.] }`. In this case the energy bias of type 1 and 2 are set to 0. and 1., respectively. A dipole model with two atom types may set `preset_out_bias` as `{ 'dipole': [null, [0., 1., 2.]] }`" doc_finetune_head = ( "The chosen fitting net to fine-tune on, when doing multi-task fine-tuning. " "If not set or set to 'RANDOM', the fitting net will be randomly initialized." @@ -1837,7 +1837,7 @@ def model_args(exclude_hybrid=False): ), Argument( "preset_out_bias", - Dict[str, Optional[float]], + Dict[str, List[Optional[Union[float, List[float]]]]], optional=True, default=None, doc=doc_only_pt_supported + doc_preset_out_bias, diff --git a/source/tests/pt/model/test_get_model.py b/source/tests/pt/model/test_get_model.py index c433597d5a..4774582f57 100644 --- a/source/tests/pt/model/test_get_model.py +++ b/source/tests/pt/model/test_get_model.py @@ -63,6 +63,37 @@ def test_model_attr(self): self.assertEqual(atomic_model.atom_exclude_types, [1]) self.assertEqual(atomic_model.pair_exclude_types, [[1, 2]]) + def test_model_attr_energy_float(self): + model_params = copy.deepcopy(model_se_e2_a) + model_params["preset_out_bias"] = {"energy": ["1.", 3, None]} + self.model = get_model(model_params).to(env.DEVICE) + atomic_model = self.model.atomic_model + self.assertEqual(atomic_model.type_map, ["O", "H", "B"]) + self.assertEqual( + atomic_model.preset_out_bias, + { + "energy": [ + np.array([1.0]), + np.array([3.0]), + None, + ] + }, + ) + self.assertEqual(atomic_model.atom_exclude_types, [1]) + self.assertEqual(atomic_model.pair_exclude_types, [[1, 2]]) + + def test_model_attr_energy_unsupported_type(self): + model_params = copy.deepcopy(model_se_e2_a) + model_params["preset_out_bias"] = {"energy": [1.0 + 2.0j, 3, None]} + with self.assertRaises(ValueError): + self.model = get_model(model_params).to(env.DEVICE) + + def test_model_attr_energy_unsupported_value(self): + model_params = copy.deepcopy(model_se_e2_a) + model_params["preset_out_bias"] = {"energy": ["1.0 + 2.0j", 3, None]} + with self.assertRaises(ValueError): + self.model = get_model(model_params).to(env.DEVICE) + def test_notset_model_attr(self): model_params = copy.deepcopy(model_se_e2_a) model_params.pop("atom_exclude_types")