Skip to content

Commit

Permalink
mace_mp() pass model_path as model to adapt new API added in ACEsuit/…
Browse files Browse the repository at this point in the history
…mace#230 and silence deprecation warning
  • Loading branch information
janosh committed Nov 20, 2023
1 parent dd4b374 commit 25ccc98
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
17 changes: 9 additions & 8 deletions src/atomate2/forcefields/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from atomate2.forcefields.utils import Relaxer

if TYPE_CHECKING:
from collections.abc import Sequence
from pathlib import Path

from pymatgen.core.structure import Structure
Expand Down Expand Up @@ -309,13 +310,13 @@ class MACERelaxMaker(ForceFieldRelaxMaker):
Keyword arguments that will get passed to :obj:`Relaxer()`.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
model_path: str | Path | None
model: str | Path | None
Checkpoint to load with :obj:`mace.calculators.MACECalculator()'`. Can be a URL
starting with https://. If None, loads the universal MACE trained for Matbench
Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
model_kwargs: dict[str, Any]
Further keywords (e.g. device, default_dtype, model_paths) for
Further keywords (e.g. device, default_dtype, model) for
:obj:`mace.calculators.MACECalculator()'`.
"""

Expand All @@ -326,13 +327,13 @@ class MACERelaxMaker(ForceFieldRelaxMaker):
relax_kwargs: dict = field(default_factory=dict)
optimizer_kwargs: dict = field(default_factory=dict)
task_document_kwargs: dict = field(default_factory=dict)
model_path: str = None
model: str | Path | Sequence[str | Path] | None = None
model_kwargs: dict = field(default_factory=dict)

def _relax(self, structure: Structure) -> dict:
from mace.calculators import mace_mp

calculator = mace_mp(model_path=self.model_path, **self.model_kwargs)
calculator = mace_mp(model=self.model, **self.model_kwargs)
relaxer = Relaxer(
calculator, relax_cell=self.relax_cell, **self.optimizer_kwargs
)
Expand All @@ -352,26 +353,26 @@ class MACEStaticMaker(ForceFieldStaticMaker):
The name of the force field.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
model_path: str | Path | None
model: str | Path | None
Checkpoint to load with :obj:`mace.calculators.MACECalculator()'`. Can be a URL
starting with https://. If None, loads the universal MACE trained for Matbench
Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
model_kwargs: dict[str, Any]
Further keywords (e.g. device, default_dtype, model_paths) for
Further keywords (e.g. device, default_dtype, model) for
:obj:`mace.calculators.MACECalculator()'`.
"""

name: str = "MACE static"
force_field_name: str = "MACE"
task_document_kwargs: dict = field(default_factory=dict)
model_path: str = None
model: str | Path | Sequence[str | Path] | None = None
model_kwargs: dict = field(default_factory=dict)

def _evaluate_static(self, structure: Structure) -> dict:
from mace.calculators import mace_mp

calculator = mace_mp(model_path=self.model_path, **self.model_kwargs)
calculator = mace_mp(model=self.model, **self.model_kwargs)
relaxer = Relaxer(calculator, relax_cell=False)
return relaxer.relax(structure, steps=1)

Expand Down
2 changes: 1 addition & 1 deletion src/atomate2/vasp/flows/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class PhononMaker(Maker):
forces are computed for these structures. With the help of phonopy, these
forces are then converted into a dynamical matrix. To correct for polarization
effects, a correction of the dynamical matrix based on BORN charges can
be performed. Finally, phonon densities of states, phonon band structures
be performed. Finally, phonon densities of states, phonon band structures
and thermodynamic properties are computed.
.. Note::
Expand Down

0 comments on commit 25ccc98

Please sign in to comment.