From 84e8d55132b2242fad06f9b7f7706b35c1a7da7e Mon Sep 17 00:00:00 2001
From: Janosh Riebesell <janosh.riebesell@gmail.com>
Date: Sat, 16 Nov 2024 21:14:27 +0000
Subject: [PATCH] CHGNetCalculator add kwarg task: PredTask = "efsm" (#215)

---
 .pre-commit-config.yaml      | 10 +++++-----
 chgnet/model/dynamics.py     | 36 +++++++++++++++++++++++-------------
 chgnet/model/model.py        |  9 ++++++---
 chgnet/trainer/trainer.py    |  2 +-
 site/.gitignore              |  1 -
 site/package.json            | 36 ++++++++++++++++++------------------
 site/src/routes/+page.svelte |  5 +----
 site/vite.config.ts          | 10 ----------
 tests/test_md.py             | 28 ++++++++++++++++++++++++++--
 tests/test_relaxation.py     |  4 +++-
 tests/test_trainer.py        |  6 +++---
 11 files changed, 86 insertions(+), 61 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5f0a13d2..bc3acb2d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,10 +1,10 @@
-default_stages: [commit]
+default_stages: [pre-commit]
 
 default_install_hook_types: [pre-commit, commit-msg]
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.6.9
+    rev: v0.7.4
     hooks:
       - id: ruff
         args: [--fix]
@@ -28,11 +28,11 @@ repos:
     rev: v2.3.0
     hooks:
       - id: codespell
-        stages: [commit, commit-msg]
+        stages: [pre-commit, commit-msg]
         args: [--check-filenames]
 
   - repo: https://github.com/kynan/nbstripout
-    rev: 0.7.1
+    rev: 0.8.0
     hooks:
       - id: nbstripout
         args: [--drop-empty-cells, --keep-output]
@@ -48,7 +48,7 @@ repos:
           - svelte
 
   - repo: https://github.com/pre-commit/mirrors-eslint
-    rev: v9.12.0
+    rev: v9.15.0
     hooks:
       - id: eslint
         types: [file]
diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py
index b5b01f97..8b03bf0b 100644
--- a/chgnet/model/dynamics.py
+++ b/chgnet/model/dynamics.py
@@ -33,6 +33,8 @@
     from ase.optimize.optimize import Optimizer
     from typing_extensions import Self
 
+    from chgnet import PredTask
+
 # We would like to thank M3GNet develop team for this module
 # source: https://github.com/materialsvirtuallab/m3gnet
 
@@ -59,7 +61,7 @@ def __init__(
         *,
         use_device: str | None = None,
         check_cuda_mem: bool = False,
-        stress_weight: float | None = 1 / 160.21766208,
+        stress_weight: float = units.GPa,  # GPa to eV/A^3
         on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
         return_site_energies: bool = False,
         **kwargs,
@@ -124,6 +126,7 @@ def calculate(
         atoms: Atoms | None = None,
         properties: list | None = None,
         system_changes: list | None = None,
+        task: PredTask = "efsm",
     ) -> None:
         """Calculate various properties of the atoms using CHGNet.
 
@@ -133,6 +136,8 @@ def calculate(
                 Default is all properties.
             system_changes (list | None): The changes made to the system.
                 Default is all changes.
+            task (PredTask): The task to perform. One of "e", "ef", "em", "efs", "efsm".
+                Default = "efsm"
         """
         properties = properties or all_properties
         system_changes = system_changes or all_changes
@@ -147,23 +152,28 @@ def calculate(
         graph = self.model.graph_converter(structure)
         model_prediction = self.model.predict_graph(
             graph.to(self.device),
-            task="efsm",
+            task=task,
             return_crystal_feas=True,
             return_site_energies=self.return_site_energies,
         )
 
         # Convert Result
-        factor = 1 if not self.model.is_intensive else structure.composition.num_atoms
-        self.results.update(
-            energy=model_prediction["e"] * factor,
-            forces=model_prediction["f"],
-            free_energy=model_prediction["e"] * factor,
-            magmoms=model_prediction["m"],
-            stress=model_prediction["s"] * self.stress_weight,
-            crystal_fea=model_prediction["crystal_fea"],
+        extensive_factor = len(structure) if self.model.is_intensive else 1
+        key_map = dict(
+            e=("energy", extensive_factor),
+            f=("forces", 1),
+            m=("magmoms", 1),
+            s=("stress", self.stress_weight),
         )
+        self.results |= {
+            long_key: model_prediction[key] * factor
+            for key, (long_key, factor) in key_map.items()
+            if key in model_prediction
+        }
+        self.results["free_energy"] = self.results["energy"]
+        self.results["crystal_fea"] = model_prediction["crystal_fea"]
         if self.return_site_energies:
-            self.results.update(energies=model_prediction["site_energies"])
+            self.results["energies"] = model_prediction["site_energies"]
 
 
 class StructOptimizer:
@@ -174,7 +184,7 @@ def __init__(
         model: CHGNet | CHGNetCalculator | None = None,
         optimizer_class: Optimizer | str | None = "FIRE",
         use_device: str | None = None,
-        stress_weight: float = 1 / 160.21766208,
+        stress_weight: float = units.GPa,
         on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
     ) -> None:
         """Provide a trained CHGNet model and an optimizer to relax crystal structures.
@@ -773,7 +783,7 @@ def __init__(
         model: CHGNet | CHGNetCalculator | None = None,
         optimizer_class: Optimizer | str | None = "FIRE",
         use_device: str | None = None,
-        stress_weight: float = 1 / 160.21766208,
+        stress_weight: float = units.GPa,
         on_isolated_atoms: Literal["ignore", "warn", "error"] = "error",
     ) -> None:
         """Initialize a structure optimizer object for calculation of bulk modulus.
diff --git a/chgnet/model/model.py b/chgnet/model/model.py
index d42c61c9..c1bd58f8 100644
--- a/chgnet/model/model.py
+++ b/chgnet/model/model.py
@@ -4,12 +4,13 @@
 import os
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, Literal, get_args
 
 import torch
 from pymatgen.core import Structure
 from torch import Tensor, nn
 
+from chgnet import PredTask
 from chgnet.graph import CrystalGraph, CrystalGraphConverter
 from chgnet.graph.crystalgraph import TORCH_DTYPE
 from chgnet.model.composition_model import AtomRef
@@ -27,7 +28,6 @@
 if TYPE_CHECKING:
     from typing_extensions import Self
 
-    from chgnet import PredTask
 
 module_dir = os.path.dirname(os.path.abspath(__file__))
 
@@ -603,7 +603,7 @@ def predict_graph(
 
         Args:
             graph (CrystalGraph | Sequence[CrystalGraph]): CrystalGraph(s) to predict.
-            task (str): can be 'e' 'ef', 'em', 'efs', 'efsm'
+            task (PredTask): one of 'e', 'ef', 'em', 'efs', 'efsm'
                 Default = "efsm"
             return_site_energies (bool): whether to return per-site energies.
                 Default = False
@@ -626,6 +626,9 @@ def predict_graph(
             raise TypeError(
                 f"{type(graph)=} must be CrystalGraph or list of CrystalGraphs"
             )
+        valid_tasks = get_args(PredTask)
+        if task not in valid_tasks:
+            raise ValueError(f"Invalid {task=}. Must be one of {valid_tasks}.")
 
         model_device = next(self.parameters()).device
 
diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py
index e3637212..b742118a 100644
--- a/chgnet/trainer/trainer.py
+++ b/chgnet/trainer/trainer.py
@@ -858,7 +858,7 @@ def forward(
             for mag_pred, mag_target in zip(prediction["m"], targets["m"], strict=True):
                 # exclude structures without magmom labels
                 if self.allow_missing_labels:
-                    if mag_target is not None and not np.isnan(mag_target).any():
+                    if mag_target is not None and not torch.isnan(mag_target).any():
                         mag_preds.append(mag_pred)
                         mag_targets.append(mag_target)
                         m_mae_size += mag_target.shape[0]
diff --git a/site/.gitignore b/site/.gitignore
index 59078f29..bded1f72 100644
--- a/site/.gitignore
+++ b/site/.gitignore
@@ -5,4 +5,3 @@ node_modules
 .svelte-kit
 build
 src/routes/api/*.md
-src/MetricsTable.svelte
diff --git a/site/package.json b/site/package.json
index 3474e4be..2f8156fc 100644
--- a/site/package.json
+++ b/site/package.json
@@ -15,28 +15,28 @@
     "changelog": "npx auto-changelog --package --output ../changelog.md --hide-credit --commit-limit false"
   },
   "devDependencies": {
-    "@sveltejs/adapter-static": "^3.0.2",
-    "@sveltejs/kit": "^2.5.17",
-    "@sveltejs/vite-plugin-svelte": "^3.1.1",
-    "eslint": "^9.5.0",
-    "eslint-plugin-svelte": "^2.41.0",
+    "@sveltejs/adapter-static": "^3.0.6",
+    "@sveltejs/kit": "^2.8.1",
+    "@sveltejs/vite-plugin-svelte": "^4.0.1",
+    "eslint": "^9.15.0",
+    "eslint-plugin-svelte": "^2.46.0",
     "hastscript": "^9.0.0",
-    "mdsvex": "^0.11.2",
-    "prettier": "^3.3.2",
-    "prettier-plugin-svelte": "^3.2.5",
+    "mdsvex": "^0.12.3",
+    "prettier": "^3.3.3",
+    "prettier-plugin-svelte": "^3.2.8",
     "rehype-autolink-headings": "^7.1.0",
     "rehype-slug": "^6.0.0",
-    "svelte": "^4.2.18",
-    "svelte-check": "^3.8.4",
-    "svelte-multiselect": "^10.3.0",
-    "svelte-preprocess": "^6.0.1",
+    "svelte": "^5.2.1",
+    "svelte-check": "^4.0.8",
+    "svelte-multiselect": "11.0.0-rc.1",
+    "svelte-preprocess": "^6.0.3",
     "svelte-toc": "^0.5.9",
-    "svelte-zoo": "^0.4.10",
-    "svelte2tsx": "^0.7.13",
-    "tslib": "^2.6.3",
-    "typescript": "^5.5.2",
-    "typescript-eslint": "^7.14.1",
-    "vite": "^5.3.1"
+    "svelte-zoo": "^0.4.13",
+    "svelte2tsx": "^0.7.25",
+    "tslib": "^2.8.1",
+    "typescript": "^5.6.3",
+    "typescript-eslint": "^8.14.0",
+    "vite": "^5.4.11"
   },
   "prettier": {
     "semi": false,
diff --git a/site/src/routes/+page.svelte b/site/src/routes/+page.svelte
index 7e2c6975..201fe721 100644
--- a/site/src/routes/+page.svelte
+++ b/site/src/routes/+page.svelte
@@ -1,12 +1,9 @@
 <script lang="ts">
   import Readme from '$root/README.md'
-  import MetricsTable from '$src/MetricsTable.svelte'
 </script>
 
 <main>
-  <Readme>
-    <MetricsTable slot="metrics-table" />
-  </Readme>
+  <Readme />
 </main>
 
 <style>
diff --git a/site/vite.config.ts b/site/vite.config.ts
index c0cd4a6d..7765a169 100644
--- a/site/vite.config.ts
+++ b/site/vite.config.ts
@@ -1,16 +1,6 @@
 import { sveltekit } from '@sveltejs/kit/vite'
-import * as fs from 'fs'
 import type { UserConfig } from 'vite'
 
-// fetch latest Matbench Discovery metrics table at build time and save to src/ dir
-await fetch(
-  `https://github.com/janosh/matbench-discovery/raw/main/site/src/figs/metrics-table-uniq-protos.svelte`,
-)
-  .then((res) => res.text())
-  .then((text) => {
-    fs.writeFileSync(`src/MetricsTable.svelte`, text)
-  })
-
 export default {
   plugins: [sveltekit()],
 
diff --git a/tests/test_md.py b/tests/test_md.py
index ec62f632..f44c21eb 100644
--- a/tests/test_md.py
+++ b/tests/test_md.py
@@ -2,7 +2,7 @@
 
 import os
 import pickle
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, Literal, get_args
 
 import numpy as np
 import pytest
@@ -22,7 +22,7 @@
 from chgnet.graph import CrystalGraphConverter
 from chgnet.model import StructOptimizer
 from chgnet.model.dynamics import CHGNetCalculator, EquationOfState, MolecularDynamics
-from chgnet.model.model import CHGNet
+from chgnet.model.model import CHGNet, PredTask
 
 if TYPE_CHECKING:
     from pathlib import Path
@@ -314,3 +314,27 @@ def test_md_crystal_feas_log(tmp_path: Path, monkeypatch: MonkeyPatch):
     assert crystal_feas[0][1] == approx(-1.4285042, abs=1e-5)
     assert crystal_feas[10][0] == approx(-0.0020592688, abs=1e-5)
     assert crystal_feas[10][1] == approx(-1.4284436, abs=1e-5)
+
+
+@pytest.mark.parametrize("task", [*get_args(PredTask)])
+def test_calculator_task_valid(task: PredTask):
+    """Test CHGNetCalculator.calculate() returns only properties requested by task."""
+    key_map = dict(e="energy", f="forces", m="magmoms", s="stress")
+    calculator = CHGNetCalculator()
+    atoms = AseAtomsAdaptor.get_atoms(structure)
+    atoms.calc = calculator
+
+    calculator.calculate(atoms=atoms, task=task)
+
+    for key, prop in key_map.items():
+        assert (prop in calculator.results) == (key in task)  # present iff requested
+
+
+def test_calculator_task_invalid():
+    """Test that an unsupported task string makes calculate() raise ValueError."""
+    calculator = CHGNetCalculator()
+    atoms = AseAtomsAdaptor.get_atoms(structure)
+    atoms.calc = calculator
+
+    with pytest.raises(ValueError, match=r"Invalid task='invalid'\."):  # literal dot
+        calculator.calculate(atoms=atoms, task="invalid")
diff --git a/tests/test_relaxation.py b/tests/test_relaxation.py
index c23b675b..b5d39fff 100644
--- a/tests/test_relaxation.py
+++ b/tests/test_relaxation.py
@@ -50,7 +50,9 @@ def test_relaxation(
     assert {*traj.__dict__} == {
         *"atoms energies forces stresses magmoms atom_positions cells".split()
     }
-    assert len(traj) == 2 if algorithm == "legacy" else 4
+    assert len(traj) == (
+        2 if algorithm == "legacy" else 4
+    ), f"{len(traj)=}, {algorithm=}"
 
     # make sure final structure is more relaxed than initial one
     assert traj.energies[-1] == pytest.approx(-58.94209, rel=1e-4)
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index de1d1497..db769dc2 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -74,9 +74,9 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
     for param in chgnet.composition_model.parameters():
         assert param.requires_grad is False
     assert tmp_path.is_dir(), "Training dir was not created"
-    for target_str in ["e", "f", "s", "m"]:
-        assert ~np.isnan(trainer.training_history[target_str]["train"]).any()
-        assert ~np.isnan(trainer.training_history[target_str]["val"]).any()
+    for prop in "efsm":
+        assert ~np.isnan(trainer.training_history[prop]["train"]).any()
+        assert ~np.isnan(trainer.training_history[prop]["val"]).any()
     output_files = [file.name for file in tmp_path.iterdir()]
     for prefix in ("epoch", "bestE_", "bestF_"):
         n_matches = sum(file.startswith(prefix) for file in output_files)