diff --git a/README.md b/README.md index f3e00a8..34a6ce1 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,8 @@ print(results['parameters'].get_dict()) The geometry of the structure is fully defined through the `structure` input, which is provided by a `StructureData` node. Any other properties, e.g., the charge and what basis set to use, can be specified through the `structure` -dictionary in the `parameters` input: +dictionary in the `parameters` input. A specific SCF solver can also be specified using the `solver` keyword. For +example: ```python from ase.build import molecule @@ -142,7 +143,10 @@ from aiida.orm import Dict, StructureData, load_code builder = load_code('pyscf').get_builder() builder.structure = StructureData(ase=molecule('H2O')) builder.parameters = Dict({ - 'mean_field': {'method': 'RHF'}, + 'mean_field': { + 'method': 'RHF', + 'solver': 'CDIIS', + }, 'structure': { 'basis ': 'sto-3g', 'charge': 0, diff --git a/src/aiida_pyscf/calculations/base.py b/src/aiida_pyscf/calculations/base.py index caf8944..c6afe30 100644 --- a/src/aiida_pyscf/calculations/base.py +++ b/src/aiida_pyscf/calculations/base.py @@ -174,6 +174,7 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0 if (scf_solver := mean_field.get('solver')) is not None: valid_scf_solvers = ('DIIS', 'CDIIS', 'EDIIS', 'ADIIS') + forbidden_scf_solvers = ('SOSCF', 'NEWTON') options = ' '.join(valid_scf_solvers) if scf_solver is None: @@ -183,6 +184,13 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0 scf_solver = 'CDIIS' return '`DIIS` is an alias for CDIIS in PySCF. Using `CDIIS` explicitly instead.' + # When PySCF adds support for pickling SOSCF objects, we can remove this stanza + if scf_solver.upper() in forbidden_scf_solvers: + return ( + f'The solver `{scf_solver.upper()}` specified in `mean_field.solver` parameters is not yet ' + f'supported. Choose from: {options}' + ) + if scf_solver.upper() not in valid_scf_solvers: return ( f'Invalid solver `{scf_solver}` specified in `mean_field.solver` parameters. Choose from: {options}' diff --git a/tests/calculations/test_base.py b/tests/calculations/test_base.py index da98ffd..c80f608 100644 --- a/tests/calculations/test_base.py +++ b/tests/calculations/test_base.py @@ -210,6 +210,19 @@ def test_invalid_parameters_mean_field_solver_diis(generate_calc_job, generate_i with pytest.raises(ValueError, match=r'`DIIS` is an alias for CDIIS in PySCF. Using `CDIIS` explicitly instead.'): generate_calc_job(PyscfCalculation, inputs=inputs) +@pytest.mark.parametrize( + 'solver, expected', ( + ({'solver': 'newton'}, 'The solver `NEWTON` specified in `mean_field.solver` parameters is not yet supported.'), + ({'solver': 'sOsCf'}, 'The solver `SOSCF` specified in `mean_field.solver` parameters is not yet supported.'), + ) +) +def test_invalid_parameters_mean_field_solver_second_order(generate_calc_job, generate_inputs_pyscf, solver, expected): + """Test logic to catch second order solver input for ``parameters.mean_field.solver``.""" + parameters = {'mean_field': solver} + inputs = generate_inputs_pyscf(parameters=parameters) + with pytest.raises(ValueError, match=expected): + generate_calc_job(PyscfCalculation, inputs=inputs) + def test_invalid_parameters_mean_field_chkfile(generate_calc_job, generate_inputs_pyscf): """Test validation of ``parameters.mean_field.chkfile``, is not allowed as set automatically by plugin."""