diff --git a/atomistics/calculators/lammps/calculator.py b/atomistics/calculators/lammps/calculator.py index a0d3b928..6fee3952 100644 --- a/atomistics/calculators/lammps/calculator.py +++ b/atomistics/calculators/lammps/calculator.py @@ -25,11 +25,7 @@ LAMMPS_RUN, LAMMPS_MINIMIZE_VOLUME, ) -from atomistics.calculators.lammps.output import ( - get_static_output, - quantities_md, - quantities_static, -) +from atomistics.calculators.lammps.output import LammpsMDOutput, LammpsStaticOutput if TYPE_CHECKING: from ase import Atoms @@ -117,7 +113,7 @@ def calc_static_with_lammps( structure, potential_dataframe, lmp=None, - quantities=quantities_static, + quantities=LammpsStaticOutput.fields(), **kwargs, ): template_str = LAMMPS_THERMO_STYLE + "\n" + LAMMPS_THERMO + "\n" + LAMMPS_RUN @@ -131,10 +127,7 @@ def calc_static_with_lammps( lmp=lmp, **kwargs, ) - result_dict = get_static_output( - lmp_instance=lmp_instance, - quantities=quantities, - ) + result_dict = LammpsStaticOutput.get(lmp_instance, *quantities) lammps_shutdown(lmp_instance=lmp_instance, close_instance=lmp is None) return result_dict @@ -151,7 +144,7 @@ def calc_molecular_dynamics_nvt_with_lammps( seed=4928459, dist="gaussian", lmp=None, - quantities=quantities_md, + quantities=LammpsMDOutput.fields(), **kwargs, ): init_str = ( @@ -208,7 +201,7 @@ def calc_molecular_dynamics_npt_with_lammps( seed=4928459, dist="gaussian", lmp=None, - quantities=quantities_md, + quantities=LammpsMDOutput.fields(), **kwargs, ): init_str = ( @@ -266,7 +259,7 @@ def calc_molecular_dynamics_nph_with_lammps( seed=4928459, dist="gaussian", lmp=None, - quantities=quantities_md, + quantities=LammpsMDOutput.fields(), **kwargs, ): init_str = ( diff --git a/atomistics/calculators/lammps/helpers.py b/atomistics/calculators/lammps/helpers.py index 790bd71e..0e763f71 100644 --- a/atomistics/calculators/lammps/helpers.py +++ b/atomistics/calculators/lammps/helpers.py @@ -5,7 +5,7 @@ from pylammpsmpi import LammpsASELibrary from atomistics.calculators.lammps.potential import validate_potential_dataframe -from atomistics.calculators.lammps.output import get_md_output, quantities_md +from atomistics.calculators.lammps.output import LammpsMDOutput def lammps_run(structure, potential_dataframe, input_template=None, lmp=None, **kwargs): @@ -41,14 +41,11 @@ def lammps_calc_md_step( lmp_instance, run_str, run, - quantities=quantities_md, + quantities=LammpsMDOutput.fields(), ): run_str_rendered = Template(run_str).render(run=run) lmp_instance.interactive_lib_command(run_str_rendered) - return get_md_output( - lmp_instance=lmp_instance, - quantities=quantities, - ) + return LammpsMDOutput.get(lmp_instance, *quantities) def lammps_calc_md( @@ -56,7 +53,7 @@ def lammps_calc_md( run_str, run, thermo, - quantities=quantities_md, + quantities=LammpsMDOutput.fields(), ): results_lst = [ lammps_calc_md_step( diff --git a/atomistics/calculators/lammps/output.py b/atomistics/calculators/lammps/output.py index 52328afb..3313bb1f 100644 --- a/atomistics/calculators/lammps/output.py +++ b/atomistics/calculators/lammps/output.py @@ -9,12 +9,13 @@ class LammpsOutput: def fields(cls): return tuple(field.name for field in dataclasses.fields(cls)) - def __call__(self, engine: LammpsASELibrary, quantity: str): - return getattr(self, quantity)(engine) + @classmethod + def get(cls, engine: LammpsASELibrary, *quantities: str) -> dict: + return {q: getattr(cls, q)(engine) for q in quantities} @dataclasses.dataclass -class LammpsMDQuantityGetter(LammpsOutput): +class LammpsMDOutput(LammpsOutput): positions: callable = LammpsASELibrary.interactive_positions_getter cell: callable = LammpsASELibrary.interactive_cells_getter forces: callable = LammpsASELibrary.interactive_forces_getter @@ -26,33 +27,7 @@ class LammpsMDQuantityGetter(LammpsOutput): @dataclasses.dataclass -class LammpsStaticQuantityGetter(LammpsOutput): +class LammpsStaticOutput(LammpsOutput): forces: callable = LammpsASELibrary.interactive_forces_getter energy: callable = LammpsASELibrary.interactive_energy_pot_getter stress: callable = LammpsASELibrary.interactive_pressures_getter - - -quantity_getter_md = LammpsMDQuantityGetter() -quantities_md = quantity_getter_md.fields() -quantity_getter_static = LammpsStaticQuantityGetter() -quantities_static = quantity_getter_static.fields() - - -def get_quantity(lmp_instance, quantity_getter, quantities): - return {q: quantity_getter(lmp_instance, q) for q in quantities} - - -def get_static_output(lmp_instance, quantities=quantities_static): - return get_quantity( - lmp_instance=lmp_instance, - quantity_getter=quantity_getter_static, - quantities=quantities, - ) - - -def get_md_output(lmp_instance, quantities=quantities_md): - return get_quantity( - lmp_instance=lmp_instance, - quantity_getter=quantity_getter_md, - quantities=quantities, - )