Skip to content

Commit

Permalink
Merge pull request #108 from ACEsuit/lammps_wrapper
Browse files Browse the repository at this point in the history
Lammps wrapper
  • Loading branch information
ilyes319 authored Jun 1, 2023
2 parents 538b03c + 8676e93 commit 3f909d8
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 48 deletions.
92 changes: 44 additions & 48 deletions mace/calculators/lammps_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,13 @@ def __init__(self, model):
def forward(
self,
data: Dict[str, torch.Tensor],
mask_ghost: torch.Tensor,
compute_force: bool = True,
local_or_ghost: torch.Tensor,
compute_virials: bool = False,
compute_stress: bool = False,
) -> Dict[str, Optional[torch.Tensor]]:
num_graphs = data["ptr"].numel() - 1
compute_displacement = False
if compute_virials or compute_stress:
if compute_virials:
compute_displacement = True

out = self.model(
data,
training=False,
Expand All @@ -41,59 +38,58 @@ def forward(
)
node_energy = out["node_energy"]
if node_energy is None:
return {"energy": None, "forces": None, "virials": None, "stress": None}
return {
"total_energy_local": None,
"node_energy": None,
"forces": None,
"virials": None,
}
positions = data["positions"]
displacement = out["displacement"]
forces = torch.zeros_like(positions)
virials: Optional[torch.Tensor] = torch.zeros_like(data["cell"])
stress: Optional[torch.Tensor] = torch.zeros_like(data["cell"])
if mask_ghost is not None and displacement is not None:
# displacement.requires_grad_(True) # For some reason torchscript needs that.
node_energy_ghost = node_energy * mask_ghost
total_energy_ghost = scatter_sum(
src=node_energy_ghost, index=data["batch"], dim=-1, dim_size=num_graphs
)
grad_outputs: List[Optional[torch.Tensor]] = [
torch.ones_like(total_energy_ghost)
]
virials = torch.autograd.grad(
outputs=[total_energy_ghost],
inputs=[displacement],
# accumulate energies of local atoms
node_energy_local = node_energy * local_or_ghost
total_energy_local = scatter_sum(
src=node_energy_local, index=data["batch"], dim=-1, dim_size=num_graphs
)
# compute partial forces and (possibly) partial virials
grad_outputs: List[Optional[torch.Tensor]] = [
torch.ones_like(total_energy_local)
]
if compute_virials:
forces, virials = torch.autograd.grad(
outputs=[total_energy_local],
inputs=[positions, displacement],
grad_outputs=grad_outputs,
retain_graph=True,
create_graph=True,
retain_graph=False,
create_graph=False,
allow_unused=True,
)[0]

)
if forces is not None:
forces = -1 * forces
else:
forces = torch.zeros_like(positions)
if virials is not None:
virials = -1 * virials
cell = data["cell"].view(-1, 3, 3)
volume = torch.einsum(
"zi,zi->z",
cell[:, 0, :],
torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
).unsqueeze(-1)
stress = virials / volume.view(-1, 1, 1)
else:
virials = torch.zeros_like(displacement)

total_energy = scatter_sum(
src=node_energy, index=data["batch"], dim=-1, dim_size=num_graphs
)

forces, _, _ = get_outputs(
energy=total_energy,
positions=data["positions"],
displacement=displacement,
cell=data["cell"],
training=False,
compute_force=compute_force,
compute_virials=False,
compute_stress=False,
)

else:
forces = torch.autograd.grad(
outputs=[total_energy_local],
inputs=[positions],
grad_outputs=grad_outputs,
retain_graph=False,
create_graph=False,
allow_unused=True,
)[0]
if forces is not None:
forces = -1 * forces
else:
forces = torch.zeros_like(positions)
return {
"energy": total_energy,
"total_energy_local": total_energy_local,
"node_energy": node_energy,
"forces": forces,
"virials": virials,
"stress": stress,
}
11 changes: 11 additions & 0 deletions scripts/create_lammps_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Compile a trained MACE model into a TorchScript module usable from LAMMPS.

Usage:
    python create_lammps_model.py <model_path>

Loads the serialized model, casts it to float64 on the CPU (the LAMMPS
pair style works in double precision), wraps it in the LAMMPS_MACE adapter,
compiles it with e3nn's TorchScript helper, and writes
"<model_path>-lammps.pt" next to the input file.
"""
import sys

import torch
from e3nn.util import jit

from mace.calculators import LAMMPS_MACE

if len(sys.argv) != 2:
    sys.exit(f"Usage: {sys.argv[0]} <model_path>")

model_path = sys.argv[1]  # takes model name as command-line input
# map_location avoids a CUDA deserialization error when the checkpoint was
# saved on a GPU machine but this script runs on a CPU-only host; the model
# is moved to CPU immediately below regardless.
model = torch.load(model_path, map_location="cpu")
model = model.double().to("cpu")  # LAMMPS interface expects float64 on CPU
lammps_model = LAMMPS_MACE(model)
lammps_model_compiled = jit.compile(lammps_model)
lammps_model_compiled.save(f"{model_path}-lammps.pt")

0 comments on commit 3f909d8

Please sign in to comment.