Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ase_filter keyword to StructOptimizer.relax() #102

Merged
merged 3 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from ase import Atoms, units
from ase.calculators.calculator import Calculator, all_changes, all_properties
from ase.filters import FrechetCellFilter
from ase.filters import Filter, FrechetCellFilter
from ase.md.npt import NPT
from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen, NPTBerendsen
from ase.md.nvtberendsen import NVTBerendsen
Expand Down Expand Up @@ -211,6 +211,7 @@ def relax(
fmax: float | None = 0.1,
steps: int | None = 500,
relax_cell: bool | None = True,
ase_filter: Filter = FrechetCellFilter,
save_path: str | None = None,
loginterval: int | None = 1,
crystal_feas_save_path: str | None = None,
Expand All @@ -227,6 +228,11 @@ def relax(
Default = 500
relax_cell (bool | None): Whether to relax the cell as well.
Default = True
ase_filter (ase.filters.Filter): The filter to apply to the atoms object
for relaxation. Default = FrechetCellFilter
Used to default to ExpCellFilter but was removed due to bug reported in
https://gitlab.com/ase/ase/-/issues/1321 and fixed in
https://gitlab.com/ase/ase/-/merge_requests/3024.
save_path (str | None): The path to save the trajectory.
Default = None
loginterval (int | None): Interval for logging trajectory and crystal feas
Expand Down Expand Up @@ -255,7 +261,7 @@ def relax(
cry_obs = CrystalFeasObserver(atoms)

if relax_cell:
atoms = FrechetCellFilter(atoms)
atoms = ase_filter(atoms)
optimizer = self.optimizer_class(atoms, **kwargs)
optimizer.attach(obs, interval=loginterval)

Expand All @@ -271,7 +277,7 @@ def relax(
if crystal_feas_save_path:
cry_obs.save(crystal_feas_save_path)

if isinstance(atoms, FrechetCellFilter):
if isinstance(atoms, Filter):
atoms = atoms.atoms
struct = AseAtomsAdaptor.get_structure(atoms)
for key in struct.site_properties:
Expand Down
9 changes: 6 additions & 3 deletions tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
import torch
from ase.filters import ExpCellFilter, Filter, FrechetCellFilter
from pymatgen.core import Structure
from pytest import approx, mark, param

Expand All @@ -15,8 +16,10 @@
structure = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")


@pytest.mark.parametrize("algorithm", ["legacy", "fast"])
def test_relaxation(algorithm: Literal["legacy", "fast"]):
@pytest.mark.parametrize(
"algorithm, ase_filter", [("legacy", FrechetCellFilter), ("fast", ExpCellFilter)]
)
def test_relaxation(algorithm: Literal["legacy", "fast"], ase_filter: Filter) -> None:
chgnet = CHGNet.load()
converter = CrystalGraphConverter(
atom_graph_cutoff=6, bond_graph_cutoff=3, algorithm=algorithm
Expand All @@ -25,7 +28,7 @@ def test_relaxation(algorithm: Literal["legacy", "fast"]):

chgnet.graph_converter = converter
relaxer = StructOptimizer(model=chgnet)
result = relaxer.relax(structure, verbose=True)
result = relaxer.relax(structure, verbose=True, ase_filter=ase_filter)
assert list(result) == ["final_structure", "trajectory"]

traj = result["trajectory"]
Expand Down