# fix: type of the preset out bias (#4135)

## Summary by CodeRabbit

- **New Features**
  - Enhanced error handling for the `preset_out_bias` parameter, allowing for better validation of input types.
  - Expanded documentation for `preset_out_bias`, providing clearer guidelines and examples for users (see the sketch after this summary).

- **Bug Fixes**
  - Improved robustness by ensuring unsupported types for energy values raise appropriate errors.

- **Tests**
  - Added new tests to validate the handling of various input types for the `preset_out_bias` parameter, ensuring correct processing and error reporting.

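For readers of the summary, a minimal sketch of the value shapes this commit is about, following the documentation change below (illustrative values only, not a full training input):

```python
# Hypothetical preset_out_bias values matching the documented shapes.
# Energy model with three atom types: type 0 is fitted freely (None),
# types 1 and 2 get fixed scalar biases.
preset_energy_bias = {"energy": [None, 0.0, 1.0]}

# Dipole model with two atom types: a vector-valued bias is given as a list.
preset_dipole_bias = {"dipole": [None, [0.0, 1.0, 2.0]]}
```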

Co-authored-by: Han Wang <[email protected]>
wanghan-iapcm and Han Wang authored Sep 18, 2024
1 parent 64e6e52 · commit 2e3b251
Showing 3 changed files with 53 additions and 3 deletions.

### deepmd/pt/model/model/__init__.py (20 additions, 1 deletion)

```diff
@@ -151,6 +151,15 @@ def get_zbl_model(model_params):
     )
 
 
+def _can_be_converted_to_float(value):
+    try:
+        float(value)
+        return True
+    except (TypeError, ValueError):
+        # return false for any failure...
+        return False
+
+
 def _convert_preset_out_bias_to_array(preset_out_bias, type_map):
     if preset_out_bias is not None:
         for kk in preset_out_bias:
@@ -160,7 +169,17 @@ def _convert_preset_out_bias_to_array(preset_out_bias, type_map):
             )
             for jj in range(len(preset_out_bias[kk])):
                 if preset_out_bias[kk][jj] is not None:
-                    preset_out_bias[kk][jj] = np.array(preset_out_bias[kk][jj])
+                    if isinstance(preset_out_bias[kk][jj], list):
+                        bb = preset_out_bias[kk][jj]
+                    elif _can_be_converted_to_float(preset_out_bias[kk][jj]):
+                        bb = [float(preset_out_bias[kk][jj])]
+                    else:
+                        raise ValueError(
+                            f"unsupported type/value of the {jj}th element of "
+                            f"preset_out_bias['{kk}'] "
+                            f"{type(preset_out_bias[kk][jj])}"
+                        )
+                    preset_out_bias[kk][jj] = np.array(bb)
     return preset_out_bias
```
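
To see what the new conversion accepts, here is a small sketch of the rules above, assuming `_convert_preset_out_bias_to_array` from the diff is in scope and using a hypothetical two-type `type_map` (illustration only, not part of the commit):

```python
import numpy as np

type_map = ["O", "H"]

# A float-convertible scalar (including a numeric string) becomes a
# one-element array; a list is converted as-is; None entries are skipped.
# Previously, np.array("1.") would silently produce a string-dtype array.
bias = _convert_preset_out_bias_to_array({"energy": ["1.", None]}, type_map)
assert np.allclose(bias["energy"][0], np.array([1.0]))
assert bias["energy"][1] is None

# Values that float() cannot handle now fail loudly with ValueError.
try:
    _convert_preset_out_bias_to_array({"energy": ["1.0 + 2.0j", None]}, type_map)
except ValueError as err:
    print(err)  # unsupported type/value of the 0th element of ...
```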

### deepmd/utils/argcheck.py (2 additions, 2 deletions)

```diff
@@ -1772,7 +1772,7 @@ def model_args(exclude_hybrid=False):
     doc_spin = "The settings for systems with spin."
     doc_atom_exclude_types = "Exclude the atomic contribution of the listed atom types"
     doc_pair_exclude_types = "The atom pairs of the listed types are not treated to be neighbors, i.e. they do not see each other."
-    doc_preset_out_bias = "The preset bias of the atomic output. Is provided as a dict. Taking the energy model that has three atom types for example, the preset_out_bias may be given as `{ 'energy': [null, 0., 1.] }`. In this case the bias of type 1 and 2 are set to 0. and 1., respectively.The set_davg_zero should be set to true."
+    doc_preset_out_bias = "The preset bias of the atomic output. Note that the set_davg_zero should be set to true. The bias is provided as a dict. Taking the energy model that has three atom types for example, the `preset_out_bias` may be given as `{ 'energy': [null, 0., 1.] }`. In this case the energy bias of type 1 and 2 are set to 0. and 1., respectively. A dipole model with two atom types may set `preset_out_bias` as `{ 'dipole': [null, [0., 1., 2.]] }`"
     doc_finetune_head = (
         "The chosen fitting net to fine-tune on, when doing multi-task fine-tuning. "
         "If not set or set to 'RANDOM', the fitting net will be randomly initialized."
@@ -1837,7 +1837,7 @@
         ),
         Argument(
             "preset_out_bias",
-            Dict[str, Optional[float]],
+            Dict[str, List[Optional[Union[float, List[float]]]]],
             optional=True,
             default=None,
             doc=doc_only_pt_supported + doc_preset_out_bias,
```
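
Read as a type, the new declaration says each property maps to a per-type list whose entries are `None` (no preset), a scalar, or a vector. A minimal sketch with standard `typing` aliases (the alias name is mine, not from the codebase):

```python
from typing import Dict, List, Optional, Union

# Hypothetical alias mirroring the new dargs declaration above.
PresetOutBias = Dict[str, List[Optional[Union[float, List[float]]]]]

# The old declaration, Dict[str, Optional[float]], could not describe these:
ok_scalar: PresetOutBias = {"energy": [None, 0.0, 1.0]}  # per-type scalars
ok_vector: PresetOutBias = {"dipole": [None, [0.0, 1.0, 2.0]]}  # per-type vector
```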

### source/tests/pt/model/test_get_model.py (31 additions, 0 deletions)

```diff
@@ -63,6 +63,37 @@ def test_model_attr(self):
         self.assertEqual(atomic_model.atom_exclude_types, [1])
         self.assertEqual(atomic_model.pair_exclude_types, [[1, 2]])
 
+    def test_model_attr_energy_float(self):
+        model_params = copy.deepcopy(model_se_e2_a)
+        model_params["preset_out_bias"] = {"energy": ["1.", 3, None]}
+        self.model = get_model(model_params).to(env.DEVICE)
+        atomic_model = self.model.atomic_model
+        self.assertEqual(atomic_model.type_map, ["O", "H", "B"])
+        self.assertEqual(
+            atomic_model.preset_out_bias,
+            {
+                "energy": [
+                    np.array([1.0]),
+                    np.array([3.0]),
+                    None,
+                ]
+            },
+        )
+        self.assertEqual(atomic_model.atom_exclude_types, [1])
+        self.assertEqual(atomic_model.pair_exclude_types, [[1, 2]])
+
+    def test_model_attr_energy_unsupported_type(self):
+        model_params = copy.deepcopy(model_se_e2_a)
+        model_params["preset_out_bias"] = {"energy": [1.0 + 2.0j, 3, None]}
+        with self.assertRaises(ValueError):
+            self.model = get_model(model_params).to(env.DEVICE)
+
+    def test_model_attr_energy_unsupported_value(self):
+        model_params = copy.deepcopy(model_se_e2_a)
+        model_params["preset_out_bias"] = {"energy": ["1.0 + 2.0j", 3, None]}
+        with self.assertRaises(ValueError):
+            self.model = get_model(model_params).to(env.DEVICE)
+
     def test_notset_model_attr(self):
         model_params = copy.deepcopy(model_se_e2_a)
         model_params.pop("atom_exclude_types")
```
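
The three new tests hinge on how Python's built-in `float()` coerces values, which is exactly what `_can_be_converted_to_float` probes. A quick interpreter sketch (not part of the commit):

```python
# Accepted: numeric strings and ints coerce cleanly, so "1." and 3 both
# end up as np.array([1.0]) and np.array([3.0]) in the first test.
assert float("1.") == 1.0
assert float(3) == 3.0

# Rejected: a complex number raises TypeError, a non-numeric string raises
# ValueError; both are caught by the helper, so get_model raises ValueError.
for bad in (1.0 + 2.0j, "1.0 + 2.0j"):
    try:
        float(bad)
    except (TypeError, ValueError) as err:
        print(type(err).__name__, err)
```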
