Skip to content

Commit

Permalink
MACE Static/RelaxMakers default to loading mace_mp instead of test model
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Nov 16, 2023
1 parent 0738890 commit c74e0f1
Showing 1 changed file with 20 additions and 22 deletions.
42 changes: 20 additions & 22 deletions src/atomate2/forcefields/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,12 @@ class MACERelaxMaker(ForceFieldRelaxMaker):
Keyword arguments that will get passed to :obj:`Relaxer()`.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
potential_param_file_name: str | Path
param_file_name for :obj:`mace.calculators.MACECalculator()'`.
potential_kwargs: dict[str, Any]
model_path: str | Path | None
Checkpoint to load with :obj:`mace.calculators.MACECalculator()'`. Can be a URL
starting with https://. If None, loads the universal MACE trained for Matbench
Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
model_kwargs: dict[str, Any]
Further keywords (e.g. device, default_dtype, model_paths) for
:obj:`mace.calculators.MACECalculator()'`.
"""
Expand All @@ -323,17 +326,13 @@ class MACERelaxMaker(ForceFieldRelaxMaker):
relax_kwargs: dict = field(default_factory=dict)
optimizer_kwargs: dict = field(default_factory=dict)
task_document_kwargs: dict = field(default_factory=dict)
potential_param_file_name: str = "MACE.model"
potential_kwargs: dict = field(default_factory=dict)
model_path: str = None
model_kwargs: dict = field(default_factory=dict)

def _relax(self, structure: Structure) -> dict:
from mace.calculators import MACECalculator
from mace.calculators import mace_mp

self.potential_kwargs.setdefault("device", "auto")

calculator = MACECalculator(
model_paths=self.potential_param_file_name, **self.potential_kwargs
)
calculator = mace_mp(model_path=self.model_path, **self.model_kwargs)
relaxer = Relaxer(calculator, relax_cell=self.relax_cell)
return relaxer.relax(structure, steps=self.steps, **self.relax_kwargs)

Expand All @@ -351,27 +350,26 @@ class MACEStaticMaker(ForceFieldStaticMaker):
The name of the force field.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
potential_param_file_name: str | Path
param_file_name for :obj:`mace.calculators.MACECalculator()'`.
potential_kwargs: dict[str, Any]
model_path: str | Path | None
Checkpoint to load with :obj:`mace.calculators.MACECalculator()'`. Can be a URL
starting with https://. If None, loads the universal MACE trained for Matbench
Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
model_kwargs: dict[str, Any]
Further keywords (e.g. device, default_dtype, model_paths) for
:obj:`mace.calculators.MACECalculator()'`.
"""

name: str = "MACE static"
force_field_name: str = "MACE"
task_document_kwargs: dict = field(default_factory=dict)
potential_param_file_name: str = "MACE.model"
potential_kwargs: dict = field(default_factory=dict)
model_path: str = None
model_kwargs: dict = field(default_factory=dict)

def _evaluate_static(self, structure: Structure) -> dict:
from mace.calculators import MACECalculator
from mace.calculators import mace_mp

self.potential_kwargs.setdefault("device", "auto")

calculator = MACECalculator(
model_paths=self.potential_param_file_name, **self.potential_kwargs
)
calculator = mace_mp(model_path=self.model_path, **self.model_kwargs)
relaxer = Relaxer(calculator, relax_cell=False)
return relaxer.relax(structure, steps=1)

Expand Down

0 comments on commit c74e0f1

Please sign in to comment.