Skip to content

Commit

Permalink
CHGNetCalculator add kwarg task: PredTask = "efsm" (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh authored Nov 16, 2024
1 parent 0da2d15 commit 84e8d55
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 61 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
default_stages: [commit]
default_stages: [pre-commit]

default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
rev: v0.7.4
hooks:
- id: ruff
args: [--fix]
Expand All @@ -28,11 +28,11 @@ repos:
rev: v2.3.0
hooks:
- id: codespell
stages: [commit, commit-msg]
stages: [pre-commit, commit-msg]
args: [--check-filenames]

- repo: https://github.com/kynan/nbstripout
rev: 0.7.1
rev: 0.8.0
hooks:
- id: nbstripout
args: [--drop-empty-cells, --keep-output]
Expand All @@ -48,7 +48,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.12.0
rev: v9.15.0
hooks:
- id: eslint
types: [file]
Expand Down
36 changes: 23 additions & 13 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from ase.optimize.optimize import Optimizer
from typing_extensions import Self

from chgnet import PredTask

# We would like to thank M3GNet develop team for this module
# source: https://github.com/materialsvirtuallab/m3gnet

Expand All @@ -59,7 +61,7 @@ def __init__(
*,
use_device: str | None = None,
check_cuda_mem: bool = False,
stress_weight: float | None = 1 / 160.21766208,
stress_weight: float = units.GPa, # GPa to eV/A^3
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
return_site_energies: bool = False,
**kwargs,
Expand Down Expand Up @@ -124,6 +126,7 @@ def calculate(
atoms: Atoms | None = None,
properties: list | None = None,
system_changes: list | None = None,
task: PredTask = "efsm",
) -> None:
"""Calculate various properties of the atoms using CHGNet.
Expand All @@ -133,6 +136,8 @@ def calculate(
Default is all properties.
system_changes (list | None): The changes made to the system.
Default is all changes.
task (PredTask): The task to perform. One of "e", "ef", "em", "efs", "efsm".
Default = "efsm"
"""
properties = properties or all_properties
system_changes = system_changes or all_changes
Expand All @@ -147,23 +152,28 @@ def calculate(
graph = self.model.graph_converter(structure)
model_prediction = self.model.predict_graph(
graph.to(self.device),
task="efsm",
task=task,
return_crystal_feas=True,
return_site_energies=self.return_site_energies,
)

# Convert Result
factor = 1 if not self.model.is_intensive else structure.composition.num_atoms
self.results.update(
energy=model_prediction["e"] * factor,
forces=model_prediction["f"],
free_energy=model_prediction["e"] * factor,
magmoms=model_prediction["m"],
stress=model_prediction["s"] * self.stress_weight,
crystal_fea=model_prediction["crystal_fea"],
extensive_factor = len(structure) if self.model.is_intensive else 1
key_map = dict(
e=("energy", extensive_factor),
f=("forces", 1),
m=("magmoms", 1),
s=("stress", self.stress_weight),
)
self.results |= {
long_key: model_prediction[key] * factor
for key, (long_key, factor) in key_map.items()
if key in model_prediction
}
self.results["free_energy"] = self.results["energy"]
self.results["crystal_fea"] = model_prediction["crystal_fea"]
if self.return_site_energies:
self.results.update(energies=model_prediction["site_energies"])
self.results["energies"] = model_prediction["site_energies"]


class StructOptimizer:
Expand All @@ -174,7 +184,7 @@ def __init__(
model: CHGNet | CHGNetCalculator | None = None,
optimizer_class: Optimizer | str | None = "FIRE",
use_device: str | None = None,
stress_weight: float = 1 / 160.21766208,
stress_weight: float = units.GPa,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
) -> None:
"""Provide a trained CHGNet model and an optimizer to relax crystal structures.
Expand Down Expand Up @@ -773,7 +783,7 @@ def __init__(
model: CHGNet | CHGNetCalculator | None = None,
optimizer_class: Optimizer | str | None = "FIRE",
use_device: str | None = None,
stress_weight: float = 1 / 160.21766208,
stress_weight: float = units.GPa,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "error",
) -> None:
"""Initialize a structure optimizer object for calculation of bulk modulus.
Expand Down
9 changes: 6 additions & 3 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import os
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, get_args

import torch
from pymatgen.core import Structure
from torch import Tensor, nn

from chgnet import PredTask
from chgnet.graph import CrystalGraph, CrystalGraphConverter
from chgnet.graph.crystalgraph import TORCH_DTYPE
from chgnet.model.composition_model import AtomRef
Expand All @@ -27,7 +28,6 @@
if TYPE_CHECKING:
from typing_extensions import Self

from chgnet import PredTask

module_dir = os.path.dirname(os.path.abspath(__file__))

Expand Down Expand Up @@ -603,7 +603,7 @@ def predict_graph(
Args:
graph (CrystalGraph | Sequence[CrystalGraph]): CrystalGraph(s) to predict.
task (str): can be 'e' 'ef', 'em', 'efs', 'efsm'
task (PredTask): one of 'e', 'ef', 'em', 'efs', 'efsm'
Default = "efsm"
return_site_energies (bool): whether to return per-site energies.
Default = False
Expand All @@ -626,6 +626,9 @@ def predict_graph(
raise TypeError(
f"{type(graph)=} must be CrystalGraph or list of CrystalGraphs"
)
valid_tasks = get_args(PredTask)
if task not in valid_tasks:
raise ValueError(f"Invalid {task=}. Must be one of {valid_tasks}.")

model_device = next(self.parameters()).device

Expand Down
2 changes: 1 addition & 1 deletion chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ def forward(
for mag_pred, mag_target in zip(prediction["m"], targets["m"], strict=True):
# exclude structures without magmom labels
if self.allow_missing_labels:
if mag_target is not None and not np.isnan(mag_target).any():
if mag_target is not None and not torch.isnan(mag_target).any():
mag_preds.append(mag_pred)
mag_targets.append(mag_target)
m_mae_size += mag_target.shape[0]
Expand Down
1 change: 0 additions & 1 deletion site/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@ node_modules
.svelte-kit
build
src/routes/api/*.md
src/MetricsTable.svelte
36 changes: 18 additions & 18 deletions site/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,28 @@
"changelog": "npx auto-changelog --package --output ../changelog.md --hide-credit --commit-limit false"
},
"devDependencies": {
"@sveltejs/adapter-static": "^3.0.2",
"@sveltejs/kit": "^2.5.17",
"@sveltejs/vite-plugin-svelte": "^3.1.1",
"eslint": "^9.5.0",
"eslint-plugin-svelte": "^2.41.0",
"@sveltejs/adapter-static": "^3.0.6",
"@sveltejs/kit": "^2.8.1",
"@sveltejs/vite-plugin-svelte": "^4.0.1",
"eslint": "^9.15.0",
"eslint-plugin-svelte": "^2.46.0",
"hastscript": "^9.0.0",
"mdsvex": "^0.11.2",
"prettier": "^3.3.2",
"prettier-plugin-svelte": "^3.2.5",
"mdsvex": "^0.12.3",
"prettier": "^3.3.3",
"prettier-plugin-svelte": "^3.2.8",
"rehype-autolink-headings": "^7.1.0",
"rehype-slug": "^6.0.0",
"svelte": "^4.2.18",
"svelte-check": "^3.8.4",
"svelte-multiselect": "^10.3.0",
"svelte-preprocess": "^6.0.1",
"svelte": "^5.2.1",
"svelte-check": "^4.0.8",
"svelte-multiselect": "11.0.0-rc.1",
"svelte-preprocess": "^6.0.3",
"svelte-toc": "^0.5.9",
"svelte-zoo": "^0.4.10",
"svelte2tsx": "^0.7.13",
"tslib": "^2.6.3",
"typescript": "^5.5.2",
"typescript-eslint": "^7.14.1",
"vite": "^5.3.1"
"svelte-zoo": "^0.4.13",
"svelte2tsx": "^0.7.25",
"tslib": "^2.8.1",
"typescript": "^5.6.3",
"typescript-eslint": "^8.14.0",
"vite": "^5.4.11"
},
"prettier": {
"semi": false,
Expand Down
5 changes: 1 addition & 4 deletions site/src/routes/+page.svelte
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
<script lang="ts">
import Readme from '$root/README.md'
import MetricsTable from '$src/MetricsTable.svelte'
</script>

<main>
<Readme>
<MetricsTable slot="metrics-table" />
</Readme>
<Readme />
</main>

<style>
Expand Down
10 changes: 0 additions & 10 deletions site/vite.config.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
import { sveltekit } from '@sveltejs/kit/vite'
import * as fs from 'fs'
import type { UserConfig } from 'vite'

// fetch latest Matbench Discovery metrics table at build time and save to src/ dir
await fetch(
`https://github.com/janosh/matbench-discovery/raw/main/site/src/figs/metrics-table-uniq-protos.svelte`,
)
.then((res) => res.text())
.then((text) => {
fs.writeFileSync(`src/MetricsTable.svelte`, text)
})

export default {
plugins: [sveltekit()],

Expand Down
28 changes: 26 additions & 2 deletions tests/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import pickle
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, get_args

import numpy as np
import pytest
Expand All @@ -22,7 +22,7 @@
from chgnet.graph import CrystalGraphConverter
from chgnet.model import StructOptimizer
from chgnet.model.dynamics import CHGNetCalculator, EquationOfState, MolecularDynamics
from chgnet.model.model import CHGNet
from chgnet.model.model import CHGNet, PredTask

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -314,3 +314,27 @@ def test_md_crystal_feas_log(tmp_path: Path, monkeypatch: MonkeyPatch):
assert crystal_feas[0][1] == approx(-1.4285042, abs=1e-5)
assert crystal_feas[10][0] == approx(-0.0020592688, abs=1e-5)
assert crystal_feas[10][1] == approx(-1.4284436, abs=1e-5)


@pytest.mark.parametrize("task", [*get_args(PredTask)])
def test_calculator_task_valid(task: PredTask):
"""Test that the task kwarg of CHGNetCalculator.calculate() works correctly."""
key_map = dict(e="energy", f="forces", m="magmoms", s="stress")
calculator = CHGNetCalculator()
atoms = AseAtomsAdaptor.get_atoms(structure)
atoms.calc = calculator

calculator.calculate(atoms=atoms, task=task)

for key, prop in key_map.items():
assert (prop in calculator.results) == (key in task)


def test_calculator_task_invalid():
"""Test that invalid task raises ValueError."""
calculator = CHGNetCalculator()
atoms = AseAtomsAdaptor.get_atoms(structure)
atoms.calc = calculator

with pytest.raises(ValueError, match="Invalid task='invalid'."):
calculator.calculate(atoms=atoms, task="invalid")
4 changes: 3 additions & 1 deletion tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def test_relaxation(
assert {*traj.__dict__} == {
*"atoms energies forces stresses magmoms atom_positions cells".split()
}
assert len(traj) == 2 if algorithm == "legacy" else 4
assert len(traj) == (
2 if algorithm == "legacy" else 4
), f"{len(traj)=}, {algorithm=}"

# make sure final structure is more relaxed than initial one
assert traj.energies[-1] == pytest.approx(-58.94209, rel=1e-4)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
for param in chgnet.composition_model.parameters():
assert param.requires_grad is False
assert tmp_path.is_dir(), "Training dir was not created"
for target_str in ["e", "f", "s", "m"]:
assert ~np.isnan(trainer.training_history[target_str]["train"]).any()
assert ~np.isnan(trainer.training_history[target_str]["val"]).any()
for prop in "efsm":
assert ~np.isnan(trainer.training_history[prop]["train"]).any()
assert ~np.isnan(trainer.training_history[prop]["val"]).any()
output_files = [file.name for file in tmp_path.iterdir()]
for prefix in ("epoch", "bestE_", "bestF_"):
n_matches = sum(file.startswith(prefix) for file in output_files)
Expand Down

0 comments on commit 84e8d55

Please sign in to comment.