diff --git a/src/atomate2/forcefields/jobs.py b/src/atomate2/forcefields/jobs.py index 8e1019ddee..8c392bafa5 100644 --- a/src/atomate2/forcefields/jobs.py +++ b/src/atomate2/forcefields/jobs.py @@ -309,9 +309,12 @@ class MACERelaxMaker(ForceFieldRelaxMaker): Keyword arguments that will get passed to :obj:`Relaxer()`. task_document_kwargs : dict Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`. - potential_param_file_name: str | Path - param_file_name for :obj:`mace.calculators.MACECalculator()'`. - potential_kwargs: dict[str, Any] + model_path: str | Path | None + Checkpoint to load with :obj:`mace.calculators.MACECalculator()'`. Can be a URL + starting with https://. If None, loads the universal MACE trained for Matbench + Discovery on the MPtrj dataset available at + https://figshare.com/articles/dataset/22715158. + model_kwargs: dict[str, Any] Further keywords (e.g. device, default_dtype, model_paths) for :obj:`mace.calculators.MACECalculator()'`. """ @@ -323,17 +326,13 @@ class MACERelaxMaker(ForceFieldRelaxMaker): relax_kwargs: dict = field(default_factory=dict) optimizer_kwargs: dict = field(default_factory=dict) task_document_kwargs: dict = field(default_factory=dict) - potential_param_file_name: str = "MACE.model" - potential_kwargs: dict = field(default_factory=dict) + model_path: str = None + model_kwargs: dict = field(default_factory=dict) def _relax(self, structure: Structure) -> dict: - from mace.calculators import MACECalculator + from mace.calculators import mace_mp - self.potential_kwargs.setdefault("device", "auto") - - calculator = MACECalculator( - model_paths=self.potential_param_file_name, **self.potential_kwargs - ) + calculator = mace_mp(model_path=self.model_path, **self.model_kwargs) relaxer = Relaxer(calculator, relax_cell=self.relax_cell) return relaxer.relax(structure, steps=self.steps, **self.relax_kwargs) @@ -351,9 +350,12 @@ class MACEStaticMaker(ForceFieldStaticMaker): The name of the force field. task_document_kwargs : dict Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`. - potential_param_file_name: str | Path - param_file_name for :obj:`mace.calculators.MACECalculator()'`. - potential_kwargs: dict[str, Any] + model_path: str | Path | None + Checkpoint to load with :obj:`mace.calculators.MACECalculator()'`. Can be a URL + starting with https://. If None, loads the universal MACE trained for Matbench + Discovery on the MPtrj dataset available at + https://figshare.com/articles/dataset/22715158. + model_kwargs: dict[str, Any] Further keywords (e.g. device, default_dtype, model_paths) for :obj:`mace.calculators.MACECalculator()'`. """ @@ -361,17 +363,13 @@ class MACEStaticMaker(ForceFieldStaticMaker): name: str = "MACE static" force_field_name: str = "MACE" task_document_kwargs: dict = field(default_factory=dict) - potential_param_file_name: str = "MACE.model" - potential_kwargs: dict = field(default_factory=dict) + model_path: str = None + model_kwargs: dict = field(default_factory=dict) def _evaluate_static(self, structure: Structure) -> dict: - from mace.calculators import MACECalculator + from mace.calculators import mace_mp - self.potential_kwargs.setdefault("device", "auto") - - calculator = MACECalculator( - model_paths=self.potential_param_file_name, **self.potential_kwargs - ) + calculator = mace_mp(model_path=self.model_path, **self.model_kwargs) relaxer = Relaxer(calculator, relax_cell=False) return relaxer.relax(structure, steps=1)