-
Notifications
You must be signed in to change notification settings - Fork 103
/
Copy pathutils.py
213 lines (182 loc) · 6.25 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
"""Utils for using a force field (aka an interatomic potential).
The following code has been taken and modified from
https://github.com/materialsvirtuallab/m3gnet
The code has been released under BSD 3-Clause License
and the following copyright applies:
Copyright (c) 2022, Materials Virtual Lab.
"""
from __future__ import annotations
import contextlib
import io
import pickle
import sys
import warnings
from typing import TYPE_CHECKING
from ase.optimize import BFGS, FIRE, LBFGS, BFGSLineSearch, LBFGSLineSearch, MDMin
from ase.optimize.sciopt import SciPyFminBFGS, SciPyFminCG
from pymatgen.core.structure import Molecule, Structure
from pymatgen.io.ase import AseAtomsAdaptor
try:
from ase.filters import FrechetCellFilter
except ImportError:
FrechetCellFilter = None
warnings.warn(
"Due to errors in the implementation of gradients in the ASE"
" ExpCellFilter, we recommend installing ASE from gitlab\n"
" pip install git+https://gitlab.com/ase/ase\n"
"rather than PyPi to access FrechetCellFilter. See\n"
" https://wiki.fysik.dtu.dk/ase/ase/filters.html#the-frechetcellfilter-class\n"
"for more details. Otherwise, you must specify an alternate ASE Filter.",
stacklevel=2,
)
if TYPE_CHECKING:
from os import PathLike
from typing import Any
import numpy as np
from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.filters import Filter
from ase.optimize.optimize import Optimizer
# Registry of optimizer names accepted by ``Relaxer(optimizer=...)`` when a
# string is given, mapped to the ASE optimizer class that will be instantiated.
OPTIMIZERS = dict(
    FIRE=FIRE,
    BFGS=BFGS,
    LBFGS=LBFGS,
    LBFGSLineSearch=LBFGSLineSearch,
    MDMin=MDMin,
    SciPyFminCG=SciPyFminCG,
    SciPyFminBFGS=SciPyFminBFGS,
    BFGSLineSearch=BFGSLineSearch,
)
class TrajectoryObserver:
    """Trajectory observer.

    Meant to be attached to an optimizer as a callback: each call appends the
    current energy, forces, stress, positions and cell of the observed
    ``Atoms`` to per-property lists, which can later be pickled with ``save``.
    """

    def __init__(self, atoms: Atoms) -> None:
        """
        Create an observer with empty per-frame buffers.

        Parameters
        ----------
        atoms (Atoms): the structure whose properties are recorded on each call.

        Returns
        -------
        None
        """
        self.atoms = atoms
        # Each list grows by exactly one entry per call, in call order.
        self.energies: list[float] = []
        self.forces: list[np.ndarray] = []
        self.stresses: list[np.ndarray] = []
        self.atom_positions: list[np.ndarray] = []
        self.cells: list[np.ndarray] = []

    def __call__(self) -> None:
        """Record one frame: energy, forces, stress, positions and cell."""
        # TODO: maybe include magnetic moments
        watched = self.atoms
        self.energies.append(self.compute_energy())
        self.forces.append(watched.get_forces())
        self.stresses.append(watched.get_stress())
        self.atom_positions.append(watched.get_positions())
        # ``[:]`` takes the raw 3x3 data out of the cell object.
        self.cells.append(watched.get_cell()[:])

    def compute_energy(self) -> float:
        """
        Return the energy recorded for a frame (the potential energy).

        Returns
        -------
        energy (float)
        """
        energy = self.atoms.get_potential_energy()
        return energy

    def save(self, filename: str | PathLike) -> None:
        """
        Pickle the recorded trajectory to ``filename``.

        Parameters
        ----------
        filename (str or PathLike): destination file, opened in binary write mode.

        Returns
        -------
        None
        """
        payload = dict(
            energy=self.energies,
            forces=self.forces,
            stresses=self.stresses,
            atom_positions=self.atom_positions,
            cell=self.cells,
            atomic_number=self.atoms.get_atomic_numbers(),
        )
        with open(filename, mode="wb") as handle:
            pickle.dump(payload, handle)
class Relaxer:
    """Relaxer is a class for structural relaxation.

    Attaches an ASE calculator to a structure and runs an ASE optimizer on it,
    optionally wrapping the atoms in a cell filter so the cell is relaxed too.
    """

    def __init__(
        self,
        calculator: Calculator,
        optimizer: type[Optimizer] | str = "FIRE",
        relax_cell: bool = True,
    ) -> None:
        """
        Initialize the Relaxer.

        Parameters
        ----------
        calculator (ase Calculator): calculator attached to the atoms before relaxing.
        optimizer (str or ase Optimizer class): either a key of ``OPTIMIZERS``
            (e.g. "FIRE", "BFGS") or an optimizer class to instantiate.
        relax_cell (bool): if True, cell parameters will be optimized.

        Raises
        ------
        ValueError: if ``optimizer`` is None or an unknown optimizer name.
        """
        self.calculator = calculator
        if optimizer is None:
            raise ValueError("Optimizer cannot be None")
        if isinstance(optimizer, str):
            # Fail fast on typos instead of storing None and crashing later
            # with "'NoneType' object is not callable" inside ``relax``.
            try:
                optimizer_obj = OPTIMIZERS[optimizer]
            except KeyError:
                valid = ", ".join(sorted(OPTIMIZERS))
                raise ValueError(
                    f"Unknown optimizer {optimizer!r}; expected one of: {valid}"
                ) from None
        else:
            optimizer_obj = optimizer
        self.opt_class: type[Optimizer] = optimizer_obj
        self.relax_cell = relax_cell
        self.ase_adaptor = AseAtomsAdaptor()

    def relax(
        self,
        atoms: Atoms | Structure | Molecule,
        fmax: float = 0.1,
        steps: int = 500,
        traj_file: str | PathLike | None = None,
        interval: int = 1,
        verbose: bool = False,
        cell_filter: type[Filter] | None = FrechetCellFilter,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Relax the structure.

        Parameters
        ----------
        atoms : Atoms, Structure or Molecule
            The atoms for relaxation. pymatgen objects are converted to ASE
            Atoms first; an Atoms object passed in is modified in place.
        fmax : float
            Force convergence criterion passed to ``optimizer.run``.
        steps : int
            Max number of steps for relaxation.
        traj_file : str, PathLike or None
            If given, the recorded trajectory is pickled to this file.
        interval : int
            The step interval for saving the trajectories.
        verbose : bool
            If True, screen output will be shown; otherwise stdout produced
            during the relaxation is captured and discarded.
        cell_filter : ASE Filter class or None
            Filter wrapping the atoms when ``relax_cell`` is True, so the cell
            is optimized together with the positions. Ignored otherwise.
        **kwargs
            Extra keyword arguments forwarded to the optimizer constructor.

        Returns
        -------
        dict with keys "final_structure" (pymatgen Structure) and
        "trajectory" (the TrajectoryObserver used during the run).

        Raises
        ------
        ValueError
            If ``relax_cell`` is True but ``cell_filter`` is None (e.g. the
            installed ASE lacks FrechetCellFilter and no filter was given).
        """
        if self.relax_cell and cell_filter is None:
            raise ValueError(
                "relax_cell=True requires a cell_filter; FrechetCellFilter is not "
                "available in this ASE version, so pass an ASE Filter class "
                "(e.g. ExpCellFilter) explicitly."
            )
        if isinstance(atoms, (Structure, Molecule)):
            atoms = self.ase_adaptor.get_atoms(atoms)
        # ``Atoms.set_calculator`` is deprecated in ASE; assign ``calc`` directly.
        atoms.calc = self.calculator
        stream = sys.stdout if verbose else io.StringIO()
        with contextlib.redirect_stdout(stream):
            obs = TrajectoryObserver(atoms)
            # The optimizer acts on ``target``: either the atoms themselves or
            # a cell filter wrapping them.
            target = cell_filter(atoms) if self.relax_cell else atoms
            optimizer = self.opt_class(target, **kwargs)
            optimizer.attach(obs, interval=interval)
            optimizer.run(fmax=fmax, steps=steps)
            # Record the final state explicitly after the run finishes.
            obs()
        if traj_file is not None:
            obs.save(traj_file)
        # Unwrap only if a filter was applied. This avoids
        # ``isinstance(atoms, cell_filter)``, which raises TypeError when
        # ``cell_filter`` is None or a non-class factory.
        final_atoms = atoms if target is atoms else target.atoms
        struct = self.ase_adaptor.get_structure(final_atoms)
        return {"final_structure": struct, "trajectory": obs}