diff --git a/tests/forcefields/test_utils.py b/tests/forcefields/test_utils.py index 023482ea2..b43eb6ff0 100644 --- a/tests/forcefields/test_utils.py +++ b/tests/forcefields/test_utils.py @@ -4,6 +4,12 @@ from atomate2.forcefields.utils import ase_calculator +@pytest.mark.parametrize(("force_field"), [mlff.value for mlff in MLFF]) +def test_mlff(force_field: str): + mlff = MLFF(force_field) + assert mlff == MLFF(str(mlff)) == MLFF(str(mlff).split(".")[-1]) + + @pytest.mark.parametrize(("force_field"), ["CHGNet", "MACE"]) def test_ext_load(force_field: str): decode_dict = {