Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BaseLammpsCalculation: Add the optional settings input #64

Merged
merged 1 commit into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions aiida_lammps/calculations/lammps/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,14 @@ def define(cls, spec):
"parameters",
valid_type=orm.Dict,
required=False,
help="Parameters that control the ``LAMMPS`` calculation",
help="Parameters that control the input script generated for the ``LAMMPS`` calculation",
)
spec.input(
"settings",
valid_type=orm.Dict,
required=False,
validator=cls.validate_settings,
help="Additional settings that control the ``LAMMPS`` calculation",
)
spec.input(
"input_restartfile",
Expand Down Expand Up @@ -196,6 +203,23 @@ def validate_inputs(cls, value, ctx):
"`parameters` have to be specified."
)

@classmethod
def validate_settings(cls, value, ctx):
"""Validate the ``settings`` input."""
if not value:
return

settings = value.get_dict()
additional_cmdline_params = settings.get("additional_cmdline_params", [])

if not isinstance(additional_cmdline_params, list) or any(
not isinstance(e, str) for e in additional_cmdline_params
):
return (
"Invalid value for `additional_cmdline_params`, should be "
f"list of strings but got: {additional_cmdline_params}"
)

def prepare_for_submission(self, folder):
"""
Create the input files from the input nodes passed to this instance of the `CalcJob`.
Expand Down Expand Up @@ -268,10 +292,16 @@ def prepare_for_submission(self, folder):
with folder.open(_input_filename, "w") as handle:
handle.write(input_filecontent)

cmdline_params = ["-in", _input_filename, "-log", _logfile_filename]

if "settings" in self.inputs:
settings = self.inputs.settings.get_dict()
cmdline_params += settings.get("additional_cmdline_params", [])

codeinfo = datastructures.CodeInfo()
# Command line variables to ensure that the input file from LAMMPS can
# be read
codeinfo.cmdline_params = ["-in", _input_filename, "-log", _logfile_filename]
codeinfo.cmdline_params = cmdline_params
# Set the code uuid
codeinfo.code_uuid = self.inputs.code.uuid
# Set the name of the stdout
Expand Down
32 changes: 31 additions & 1 deletion tests/test_calculations.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,8 +704,38 @@ def test_lammps_base_script(generate_calc_job, aiida_local_code_factory):
"""
)
stream = io.StringIO(content)
script = DataFactory("core.singlefile")(stream)
script = orm.SinglefileData(stream)

inputs["script"] = script
tmp_path, calc_info = generate_calc_job("lammps.base", inputs)
assert (tmp_path / BaseLammpsCalculation._INPUT_FILENAME).read_text() == content


def test_lammps_base_settings_invalid(generate_calc_job, aiida_local_code_factory):
"""Test the validation of the ``settings`` input."""
inputs = {
"code": aiida_local_code_factory("lammps.base", "bash"),
"settings": orm.Dict({"additional_cmdline_params": ["--option", 1]}),
"metadata": {"options": {"resources": {"num_machines": 1}}},
}

with pytest.raises(
ValueError,
match=r"Invalid value for `additional_cmdline_params`, should be list of strings but got.*",
):
generate_calc_job("lammps.base", inputs)


def test_lammps_base_settings(generate_calc_job, aiida_local_code_factory):
"""Test the ``BaseLammpsCalculation`` with the ``settings`` input."""
from aiida_lammps.calculations.lammps.base import BaseLammpsCalculation

inputs = {
"code": aiida_local_code_factory("lammps.base", "bash"),
"script": orm.SinglefileData(io.StringIO("")),
"settings": orm.Dict({"additional_cmdline_params": ["--option", "value"]}),
"metadata": {"options": {"resources": {"num_machines": 1}}},
}

_, calc_info = generate_calc_job("lammps.base", inputs)
assert calc_info.codes_info[0].cmdline_params[-2:] == ["--option", "value"]