Skip to content

Commit

Permalink
Merge pull request #59 from aiidaplugins/feature/base-input-file
Browse files Browse the repository at this point in the history
`BaseLammpsCalculation`: Add the `script` input
  • Loading branch information
JPchico authored Mar 22, 2023
2 parents d679df5 + fd8a395 commit 87db971
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 35 deletions.
91 changes: 56 additions & 35 deletions aiida_lammps/calculations/lammps/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,28 @@ class BaseLammpsCalculation(CalcJob):
@classmethod
def define(cls, spec):
super().define(spec)
spec.input(
"script",
valid_type=orm.SinglefileData,
required=False,
help="Complete input script to use. If specified, `structure`, `potential` and `parameters` are ignored.",
)
spec.input(
"structure",
valid_type=orm.StructureData,
required=True,
required=False,
help="Structure used in the ``LAMMPS`` calculation",
)
spec.input(
"potential",
valid_type=LammpsPotentialData,
required=True,
required=False,
help="Potential used in the ``LAMMPS`` calculation",
)
spec.input(
"parameters",
valid_type=orm.Dict,
required=True,
required=False,
help="Parameters that control the ``LAMMPS`` calculation",
)
spec.input(
Expand Down Expand Up @@ -102,6 +108,7 @@ def define(cls, spec):
default=cls._DEFAULT_RESTART_FILENAME,
)
spec.inputs["metadata"]["options"]["parser_name"].default = cls._DEFAULT_PARSER
spec.inputs.validator = cls.validate_inputs

spec.output(
"results",
Expand Down Expand Up @@ -178,29 +185,23 @@ def define(cls, spec):
message="error parsing the final variable file has failed.",
)

@classmethod
def validate_inputs(cls, value, ctx):
"""Validate the top-level inputs namespace."""
if "script" not in value and any(
key not in value for key in ("structure", "potential", "parameters")
):
return (
"Unless `script` is specified the inputs `structure`, `potential` and "
"`parameters` have to be specified."
)

def prepare_for_submission(self, folder):
"""
Create the input files from the input nodes passed to this instance of the `CalcJob`.
"""
# pylint: disable=too-many-locals

# Generate the content of the structure file based on the input
# structure
structure_filecontent, _ = generate_lammps_structure(
self.inputs.structure,
self.inputs.potential.atom_style,
)

# Get the name of the structure file and write it to the remote folder
_structure_filename = self.inputs.metadata.options.structure_filename

with folder.open(_structure_filename, "w") as handle:
handle.write(structure_filecontent)

# Get the parameters dictionary so that they can be used for creating
# the input file
_parameters = self.inputs.parameters.get_dict()

# Get the name of the trajectory file
_trajectory_filename = self.inputs.metadata.options.trajectory_filename

Expand All @@ -225,28 +226,48 @@ def prepare_for_submission(self, folder):
else:
_read_restart_filename = None

# Write the input file content. This function will also check the
# sanity of the passed paremters when comparing it to a schema
input_filecontent = generate_input_file(
potential=self.inputs.potential,
structure=self.inputs.structure,
parameters=_parameters,
restart_filename=_restart_filename,
trajectory_filename=_trajectory_filename,
variables_filename=_variables_filename,
read_restart_filename=_read_restart_filename,
)
if "script" in self.inputs:
input_filecontent = self.inputs.script.get_content()
else:
# Get the parameters dictionary so that they can be used for creating
# the input file
_parameters = self.inputs.parameters.get_dict()

# Generate the content of the structure file based on the input
# structure
structure_filecontent, _ = generate_lammps_structure(
self.inputs.structure,
self.inputs.potential.atom_style,
)

# Get the name of the structure file and write it to the remote folder
_structure_filename = self.inputs.metadata.options.structure_filename

with folder.open(_structure_filename, "w") as handle:
handle.write(structure_filecontent)

# Write the potential to the remote folder
with folder.open(self._DEFAULT_POTENTIAL_FILENAME, "w") as handle:
handle.write(self.inputs.potential.get_content())

# Write the input file content. This function will also check the
# sanity of the passed paremters when comparing it to a schema
input_filecontent = generate_input_file(
potential=self.inputs.potential,
structure=self.inputs.structure,
parameters=_parameters,
restart_filename=_restart_filename,
trajectory_filename=_trajectory_filename,
variables_filename=_variables_filename,
read_restart_filename=_read_restart_filename,
)

# Get the name of the input file, and write it to the remote folder
_input_filename = self.inputs.metadata.options.input_filename

with folder.open(_input_filename, "w") as handle:
handle.write(input_filecontent)

# Write the potential to the remote folder
with folder.open(self._DEFAULT_POTENTIAL_FILENAME, "w") as handle:
handle.write(self.inputs.potential.get_content())

codeinfo = datastructures.CodeInfo()
# Command line variables to ensure that the input file from LAMMPS can
# be read
Expand Down
41 changes: 41 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""
initialise a test database and profile
"""
from __future__ import annotations

from collections import namedtuple
import os
import shutil
import tempfile
import typing as t

from aiida import orm
import numpy as np
Expand Down Expand Up @@ -80,6 +83,44 @@ def db_test_app(aiida_profile, pytestconfig):
shutil.rmtree(work_directory)


@pytest.fixture
def generate_calc_job(tmp_path):
"""Create a :class:`aiida.engine.CalcJob` instance with the given inputs.
The fixture will call ``prepare_for_submission`` and return a tuple of the temporary folder that was passed to it,
as well as the ``CalcInfo`` instance that it returned.
"""

def factory(
entry_point_name: str,
inputs: dict[str, t.Any] | None = None,
return_process: bool = False,
) -> tuple[pathlib.Path, CalcInfo] | CalcJob:
"""Create a :class:`aiida.engine.CalcJob` instance with the given inputs.
:param entry_point_name: The entry point name of the calculation job plugin to run.
:param inputs: The dictionary of inputs for the calculation job.
:param return_process: Flag, if ``True``, return the constructed ``CalcJob`` instance instead of the tuple of
the temporary folder and ``CalcInfo`` instance.
"""
from aiida.common.folders import Folder
from aiida.engine.utils import instantiate_process
from aiida.manage import get_manager
from aiida.plugins import CalculationFactory

runner = get_manager().get_runner()
process_class = CalculationFactory(entry_point_name)
process = instantiate_process(runner, process_class, **inputs or {})
calc_info = process.prepare_for_submission(Folder(tmp_path))

if return_process:
return process

return tmp_path, calc_info

return factory


@pytest.fixture(scope="function")
def get_structure_data():
"""get the structure data for the simulation."""
Expand Down
37 changes: 37 additions & 0 deletions tests/test_calculations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""Test the aiida-lammps calculations."""
import io
import textwrap

from aiida import orm
from aiida.cmdline.utils.common import get_calcjob_report
from aiida.common import AttributeDict
Expand Down Expand Up @@ -672,3 +675,37 @@ def test_lammps_base(
), _msg
else:
assert sub_value == _step_data[sub_key], _msg


def test_lammps_base_script(generate_calc_job, aiida_local_code_factory):
"""Test the ``BaseLammpsCalculation`` with the ``script`` input."""
from aiida_lammps.calculations.lammps.base import BaseLammpsCalculation

inputs = {
"code": aiida_local_code_factory("lammps.base", "bash"),
"metadata": {"options": {"resources": {"num_machines": 1}}},
}

with pytest.raises(
ValueError,
match=r"Unless `script` is specified the inputs .* have to be specified.",
):
generate_calc_job("lammps.base", inputs)

content = textwrap.dedent(
"""
"velocity all create 1.44 87287 loop geom
"pair_style lj/cut 2.5
"pair_coeff 1 1 1.0 1.0 2.5
"neighbor 0.3 bin
"neigh_modify delay 0 every 20 check no
"fix 1 all nve
"run 10000
"""
)
stream = io.StringIO(content)
script = DataFactory("core.singlefile")(stream)

inputs["script"] = script
tmp_path, calc_info = generate_calc_job("lammps.base", inputs)
assert (tmp_path / BaseLammpsCalculation._INPUT_FILENAME).read_text() == content

0 comments on commit 87db971

Please sign in to comment.