Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Dataclass for ASE and QE #131

Merged
merged 15 commits into from
Dec 13, 2023
20 changes: 10 additions & 10 deletions atomistics/calculators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from atomistics.calculators.ase import (
calc_energy_and_forces_with_ase,
calc_energy_with_ase,
calc_forces_with_ase,
calc_static_with_ase,
evaluate_with_ase,
optimize_positions_with_ase,
optimize_positions_and_volume_with_ase,
)
from atomistics.calculators.qe import (
calc_energy_and_forces_with_qe,
calc_energy_with_qe,
calc_forces_with_qe,
evaluate_with_qe,
optimize_positions_and_volume_with_qe,
)

try:
from atomistics.calculators.qe import (
calc_static_with_qe,
evaluate_with_qe,
optimize_positions_and_volume_with_qe,
)
except ImportError:
pass

try:
from atomistics.calculators.lammps import (
Expand Down
118 changes: 42 additions & 76 deletions atomistics/calculators/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from ase.constraints import UnitCellFilter
from typing import TYPE_CHECKING

from atomistics.calculators.interface import get_quantities_from_tasks
from atomistics.calculators.output import OutputStatic
from atomistics.calculators.wrapper import as_task_dict_evaluator

if TYPE_CHECKING:
Expand All @@ -12,6 +14,28 @@
from atomistics.calculators.interface import TaskName


class ASEExecutor(object):
def __init__(self, ase_structure, ase_calculator):
self.structure = ase_structure
self.structure.calc = ase_calculator

def get_forces(self):
return self.structure.get_forces()

def get_energy(self):
return self.structure.get_potential_energy()

def get_stress(self):
return self.structure.get_stress(voigt=False)


ASEOutputStatic = OutputStatic(
forces=ASEExecutor.get_forces,
energy=ASEExecutor.get_energy,
stress=ASEExecutor.get_stress,
)


@as_task_dict_evaluator
def evaluate_with_ase(
structure: Atoms,
Expand All @@ -29,93 +53,35 @@ def evaluate_with_ase(
ase_optimizer_kwargs=ase_optimizer_kwargs,
)
elif "optimize_positions_and_volume" in tasks:
results["structure_with_optimized_positions_and_volume"] = (
optimize_positions_and_volume_with_ase(
structure=structure,
ase_calculator=ase_calculator,
ase_optimizer=ase_optimizer,
ase_optimizer_kwargs=ase_optimizer_kwargs,
)
results[
"structure_with_optimized_positions_and_volume"
] = optimize_positions_and_volume_with_ase(
structure=structure,
ase_calculator=ase_calculator,
ase_optimizer=ase_optimizer,
ase_optimizer_kwargs=ase_optimizer_kwargs,
)
elif "calc_energy" in tasks or "calc_forces" in tasks or "calc_stress" in tasks:
if "calc_energy" in tasks and "calc_forces" in tasks and "calc_stress" in tasks:
(
results["energy"],
results["forces"],
results["stress"],
) = calc_energy_forces_and_stress_with_ase(
structure=structure, ase_calculator=ase_calculator
)
elif "calc_energy" in tasks and "calc_forces" in tasks:
results["energy"], results["forces"] = calc_energy_and_forces_with_ase(
structure=structure, ase_calculator=ase_calculator
)
elif "calc_energy" in tasks and "calc_stress" in tasks:
results["energy"], results["forces"] = calc_energy_and_stress_with_ase(
structure=structure, ase_calculator=ase_calculator
)
elif "calc_forces" in tasks and "calc_stress" in tasks:
results["energy"], results["forces"] = calc_forces_and_stress_with_ase(
structure=structure, ase_calculator=ase_calculator
)
elif "calc_energy" in tasks:
results["energy"] = calc_energy_with_ase(
structure=structure, ase_calculator=ase_calculator
)
elif "calc_forces" in tasks:
results["forces"] = calc_forces_with_ase(
structure=structure, ase_calculator=ase_calculator
)
elif "calc_stress" in tasks:
results["stress"] = calc_stress_with_ase(
structure=structure, ase_calculator=ase_calculator
)
return calc_static_with_ase(
structure=structure,
ase_calculator=ase_calculator,
quantities=get_quantities_from_tasks(tasks=tasks),
)
else:
raise ValueError("The ASE calculator does not implement:", tasks)
return results


def calc_energy_with_ase(structure: Atoms, ase_calculator: ASECalculator):
structure.calc = ase_calculator
return structure.get_potential_energy()


def calc_energy_and_forces_with_ase(structure: Atoms, ase_calculator: ASECalculator):
structure.calc = ase_calculator
return structure.get_potential_energy(), structure.get_forces()


def calc_energy_forces_and_stress_with_ase(
structure: Atoms, ase_calculator: ASECalculator
def calc_static_with_ase(
structure,
ase_calculator,
quantities=OutputStatic.fields(),
):
structure.calc = ase_calculator
return (
structure.get_potential_energy(),
structure.get_forces(),
structure.get_stress(voigt=False),
return ASEOutputStatic.get(
ASEExecutor(ase_structure=structure, ase_calculator=ase_calculator), *quantities
)


def calc_forces_and_stress_with_ase(structure: Atoms, ase_calculator: ASECalculator):
structure.calc = ase_calculator
return structure.get_forces(), structure.get_stress(voigt=False)


def calc_energy_and_stress_with_ase(structure: Atoms, ase_calculator: ASECalculator):
structure.calc = ase_calculator
return structure.get_potential_energy(), structure.get_stress(voigt=False)


def calc_forces_with_ase(structure: Atoms, ase_calculator: ASECalculator):
structure.calc = ase_calculator
return structure.get_forces()


def calc_stress_with_ase(structure: Atoms, ase_calculator: ASECalculator):
structure.calc = ase_calculator
return structure.get_stress(voigt=False)


def optimize_positions_with_ase(
structure, ase_calculator, ase_optimizer, ase_optimizer_kwargs
):
Expand Down
11 changes: 11 additions & 0 deletions atomistics/calculators/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,14 @@ class TaskOutputEnum(Enum):
ResultsDict = dict[str, TaskResults]

SimpleEvaluator = callable[[Atoms, list[TaskName], ...], TaskResults]


def get_quantities_from_tasks(tasks):
quantities = []
if "calc_energy" in tasks:
quantities.append("energy")
if "calc_forces" in tasks:
quantities.append("forces")
if "calc_stress" in tasks:
quantities.append("stress")
return quantities
41 changes: 19 additions & 22 deletions atomistics/calculators/lammps/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pylammpsmpi import LammpsASELibrary

from atomistics.calculators.wrapper import as_task_dict_evaluator
from atomistics.calculators.interface import get_quantities_from_tasks
from atomistics.calculators.lammps.helpers import (
lammps_calc_md,
lammps_run,
Expand All @@ -25,7 +25,11 @@
LAMMPS_RUN,
LAMMPS_MINIMIZE_VOLUME,
)
from atomistics.calculators.lammps.output import LammpsMDOutput, LammpsStaticOutput
from atomistics.calculators.lammps.output import (
LammpsOutputMolecularDynamics,
LammpsOutputStatic,
)
from atomistics.calculators.wrapper import as_task_dict_evaluator

if TYPE_CHECKING:
from ase import Atoms
Expand Down Expand Up @@ -113,7 +117,7 @@ def calc_static_with_lammps(
structure,
potential_dataframe,
lmp=None,
quantities=LammpsStaticOutput.fields(),
quantities=LammpsOutputStatic.fields(),
**kwargs,
):
template_str = LAMMPS_THERMO_STYLE + "\n" + LAMMPS_THERMO + "\n" + LAMMPS_RUN
Expand All @@ -127,7 +131,7 @@ def calc_static_with_lammps(
lmp=lmp,
**kwargs,
)
result_dict = LammpsStaticOutput.get(lmp_instance, *quantities)
result_dict = LammpsOutputStatic.get(lmp_instance, *quantities)
lammps_shutdown(lmp_instance=lmp_instance, close_instance=lmp is None)
return result_dict

Expand All @@ -144,7 +148,7 @@ def calc_molecular_dynamics_nvt_with_lammps(
seed=4928459,
dist="gaussian",
lmp=None,
quantities=LammpsMDOutput.fields(),
quantities=LammpsOutputMolecularDynamics.fields(),
**kwargs,
):
init_str = (
Expand Down Expand Up @@ -201,7 +205,7 @@ def calc_molecular_dynamics_npt_with_lammps(
seed=4928459,
dist="gaussian",
lmp=None,
quantities=LammpsMDOutput.fields(),
quantities=LammpsOutputMolecularDynamics.fields(),
**kwargs,
):
init_str = (
Expand Down Expand Up @@ -259,7 +263,7 @@ def calc_molecular_dynamics_nph_with_lammps(
seed=4928459,
dist="gaussian",
lmp=None,
quantities=LammpsMDOutput.fields(),
quantities=LammpsOutputMolecularDynamics.fields(),
**kwargs,
):
init_str = (
Expand Down Expand Up @@ -362,13 +366,13 @@ def evaluate_with_lammps_library(
):
results = {}
if "optimize_positions_and_volume" in tasks:
results["structure_with_optimized_positions_and_volume"] = (
optimize_positions_and_volume_with_lammps(
structure=structure,
potential_dataframe=potential_dataframe,
lmp=lmp,
**lmp_optimizer_kwargs,
)
results[
"structure_with_optimized_positions_and_volume"
] = optimize_positions_and_volume_with_lammps(
structure=structure,
potential_dataframe=potential_dataframe,
lmp=lmp,
**lmp_optimizer_kwargs,
)
elif "optimize_positions" in tasks:
results["structure_with_optimized_positions"] = optimize_positions_with_lammps(
Expand All @@ -389,18 +393,11 @@ def evaluate_with_lammps_library(
)
results["volume_over_temperature"] = (temperature_lst, volume_md_lst)
elif "calc_energy" in tasks or "calc_forces" in tasks or "calc_stress" in tasks:
quantities = []
if "calc_energy" in tasks:
quantities.append("energy")
if "calc_forces" in tasks:
quantities.append("forces")
if "calc_stress" in tasks:
quantities.append("stress")
return calc_static_with_lammps(
structure=structure,
potential_dataframe=potential_dataframe,
lmp=lmp,
quantities=quantities,
quantities=get_quantities_from_tasks(tasks=tasks),
)
else:
raise ValueError("The LAMMPS calculator does not implement:", tasks)
Expand Down
8 changes: 4 additions & 4 deletions atomistics/calculators/lammps/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pylammpsmpi import LammpsASELibrary

from atomistics.calculators.lammps.potential import validate_potential_dataframe
from atomistics.calculators.lammps.output import LammpsMDOutput
from atomistics.calculators.lammps.output import LammpsOutputMolecularDynamics


def lammps_run(structure, potential_dataframe, input_template=None, lmp=None, **kwargs):
Expand Down Expand Up @@ -41,19 +41,19 @@ def lammps_calc_md_step(
lmp_instance,
run_str,
run,
quantities=LammpsMDOutput.fields(),
quantities=LammpsOutputMolecularDynamics.fields(),
):
run_str_rendered = Template(run_str).render(run=run)
lmp_instance.interactive_lib_command(run_str_rendered)
return LammpsMDOutput.get(lmp_instance, *quantities)
return LammpsOutputMolecularDynamics.get(lmp_instance, *quantities)


def lammps_calc_md(
lmp_instance,
run_str,
run,
thermo,
quantities=LammpsMDOutput.fields(),
quantities=LammpsOutputMolecularDynamics.fields(),
):
results_lst = [
lammps_calc_md_step(
Expand Down
46 changes: 16 additions & 30 deletions atomistics/calculators/lammps/output.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,19 @@
import dataclasses

from atomistics.calculators.output import OutputStatic, OutputMolecularDynamics
from pylammpsmpi import LammpsASELibrary


@dataclasses.dataclass
class LammpsOutput:
@classmethod
def fields(cls):
return tuple(field.name for field in dataclasses.fields(cls))

@classmethod
def get(cls, engine: LammpsASELibrary, *quantities: str) -> dict:
return {q: getattr(cls, q)(engine) for q in quantities}


@dataclasses.dataclass
class LammpsMDOutput(LammpsOutput):
positions: callable = LammpsASELibrary.interactive_positions_getter
cell: callable = LammpsASELibrary.interactive_cells_getter
forces: callable = LammpsASELibrary.interactive_forces_getter
temperature: callable = LammpsASELibrary.interactive_temperatures_getter
energy_pot: callable = LammpsASELibrary.interactive_energy_pot_getter
energy_tot: callable = LammpsASELibrary.interactive_energy_tot_getter
pressure: callable = LammpsASELibrary.interactive_pressures_getter
velocities: callable = LammpsASELibrary.interactive_velocities_getter


@dataclasses.dataclass
class LammpsStaticOutput(LammpsOutput):
forces: callable = LammpsASELibrary.interactive_forces_getter
energy: callable = LammpsASELibrary.interactive_energy_pot_getter
stress: callable = LammpsASELibrary.interactive_pressures_getter
LammpsOutputStatic = OutputStatic(
forces=LammpsASELibrary.interactive_forces_getter,
energy=LammpsASELibrary.interactive_energy_pot_getter,
stress=LammpsASELibrary.interactive_pressures_getter,
)
LammpsOutputMolecularDynamics = OutputMolecularDynamics(
positions=LammpsASELibrary.interactive_positions_getter,
cell=LammpsASELibrary.interactive_cells_getter,
forces=LammpsASELibrary.interactive_forces_getter,
temperature=LammpsASELibrary.interactive_temperatures_getter,
energy_pot=LammpsASELibrary.interactive_energy_pot_getter,
energy_tot=LammpsASELibrary.interactive_energy_tot_getter,
pressure=LammpsASELibrary.interactive_pressures_getter,
velocities=LammpsASELibrary.interactive_velocities_getter,
)
Loading
Loading