Skip to content

Commit

Permalink
add scripts/ctk_structure_viewer.py
Browse files Browse the repository at this point in the history
fix links to data files in site/src/routes/contribute/+page.md
  • Loading branch information
janosh committed Jun 20, 2023
1 parent fbe847b commit fa1a439
Show file tree
Hide file tree
Showing 20 changed files with 117 additions and 35 deletions.
2 changes: 1 addition & 1 deletion data/mp/build_phase_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
json.dump(elemental_ref_entries, file, default=lambda x: x.as_dict())


df_mp = pd.read_json(DATA_FILES.mp_energies).set_index("material_id")
df_mp = pd.read_csv(DATA_FILES.mp_energies).set_index("material_id")


# %%
Expand Down
2 changes: 1 addition & 1 deletion data/wbm/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@


# %% load MP training set
df = pd.read_json(DATA_FILES.mp_energies)
df = pd.read_csv(DATA_FILES.mp_energies)
mp_elem_counts = count_elements(df.formula_pretty).astype(int)

# mp_elem_counts.to_json(f"{about_data_page}/mp-element-counts.json")
Expand Down
2 changes: 1 addition & 1 deletion data/wbm/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ Element counts for MP training set consisting of 146,323 `ComputedStructureEntri

## 🎯   Target Distribution

The WBM test set has an energy above the MP convex hull distribution with mean **0.02 eV/atom** and standard deviation of **0.25 eV/atom**.
The WBM test set has an energy above the MP convex hull distribution with **mean ± std = 0.02 ± 0.25 eV/atom**.

The dummy MAE of always predicting the test set mean is **0.17 eV/atom**.

Expand Down
6 changes: 3 additions & 3 deletions matbench_discovery/preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ class PredFiles(Files):
bowsr_megnet = "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv"
# default CHGNet model from publication with 400,438 params
chgnet = "chgnet/2023-03-06-chgnet-wbm-IS2RE.csv"
# chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv"
chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv"
# CGCnn 10-member ensemble
cgcnn = "cgcnn/2023-01-26-test-cgcnn-wbm-IS2RE/cgcnn-ensemble-preds.csv"
# cgcnn 10-member ensemble with 5-fold training set perturbations
cgcnn_p = "cgcnn/2023-02-05-cgcnn-perturb=5.csv"
# original m3gnet straight from publication, not re-trained
m3gnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
# m3gnet-relaxed structures fed into megnet for formation energy prediction
# m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
# original megnet straight from publication, not re-trained
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"
# magpie composition+voronoi tessellation structure features + sklearn random forest
Expand Down Expand Up @@ -123,7 +123,7 @@ def load_df_wbm_preds(
)

# pick F1 as primary metric to sort by
df_metrics = df_metrics.round(3).sort_values("F1", axis=1)
df_metrics = df_metrics.round(3).sort_values("F1", axis=1, ascending=False)

# dataframe of all models' energy above convex hull (EACH) predictions (eV/atom)
df_each_pred = pd.DataFrame()
Expand Down
2 changes: 2 additions & 0 deletions matbench_discovery/structure.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import numpy as np
from pymatgen.core import Structure

Expand Down
7 changes: 4 additions & 3 deletions models/cgcnn/test_cgcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"IS2RE": DATA_FILES.wbm_initial_structures,
"RS2RE": DATA_FILES.wbm_computed_structure_entries,
"IS2RE-debug": f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json-1k-samples.bz2",
}[task_type + "-debug" if DEBUG else ""]
}[task_type + ("-debug" if DEBUG else "")]
input_col = {"IS2RE": "initial_structure", "RS2RE": "relaxed_structure"}[task_type]

df = pd.read_json(data_path).set_index("material_id")
Expand All @@ -69,9 +69,10 @@
"created_at": {"$gt": "2023-01-09", "$lt": "2023-01-10"},
}
runs = wandb.Api().runs(WANDB_PATH, filters=filters)
expected_runs = 10
assert (
len(runs) == 10
), f"Expected 10 runs, got {len(runs)} filtering {WANDB_PATH=} with {filters=}"
len(runs) == expected_runs
), f"{expected_runs=}, got {len(runs)} filtering {WANDB_PATH=} with {filters=}"

for idx, run in enumerate(runs):
for key, val in run.config.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
df_m3gnet_is2re.initial_structure.map(lambda x: x["lattice"])
).add_prefix("m3gnet_")
df_m3gnet_is2re[df_m3gnet_lattice.columns] = df_m3gnet_lattice.to_numpy()
df_m3gnet_is2re

# df_m3gnet_is2re["m3gnet_energy"] = df_m3gnet_is2re.trajectory.map(
# lambda x: x["energies"][-1][0]
Expand Down
2 changes: 1 addition & 1 deletion models/voronoi/train_test_voronoi_rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
df_train = glob_to_df(train_path).set_index("material_id")
print(f"{df_train.shape=}")

df_mp = pd.read_json(DATA_FILES.mp_energies).set_index("material_id")
df_mp = pd.read_csv(DATA_FILES.mp_energies).set_index("material_id")
train_e_form_col = "formation_energy_per_atom"

test_path = f"{module_dir}/2022-11-18-features-wbm-{task_type}.csv.bz2"
Expand Down
5 changes: 3 additions & 2 deletions models/wrenformer/test_wrenformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@
"display_name": {"$regex": "wrenformer-robust"},
}
runs = wandb.Api().runs(WANDB_PATH, filters=filters)
expected_runs = 10
assert (
len(runs) == 10
), f"Expected 10 runs, got {len(runs)} filtering {WANDB_PATH=} with {filters=}"
len(runs) == expected_runs
), f"{expected_runs=}, got {len(runs)} filtering {WANDB_PATH=} with {filters=}"

for idx, run in enumerate(runs):
for key, val in run.config.items():
Expand Down
2 changes: 1 addition & 1 deletion scripts/compile_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
),
),
"CHGNet": dict(
n_runs=102,
n_runs=100,
filters=dict(
display_name={"$regex": "chgnet-wbm-IS2RE-"},
created_at={"$gt": "2023-03-05", "$lt": "2023-03-07"},
Expand Down
72 changes: 72 additions & 0 deletions scripts/ctk_structure_viewer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations

import pandas as pd
from crystal_toolkit.helpers.utils import hook_up_fig_with_struct_viewer

__author__ = "Janosh Riebesell"
__date__ = "2023-03-07"

"""
This scripts runs a Crystal Toolkit app that shows a scatter plot of CHGNet energies
and allows to click on points to view the corresponding structures. Run with:
python scripts/ctk_structure_viewer.py
Then open http://localhost:8000 in your browser.
"""

df_plot = None
min_e_diff = 0.1
e_form_2000 = "e_form_per_atom_chgnet_2000"
e_form_500 = "e_form_per_atom_chgnet_500"

if df_plot is None:
from matbench_discovery.preds import PRED_FILES

df_chgnet = pd.read_json(PRED_FILES.__dict__["CHGNet"].replace(".csv", ".json.gz"))
df_chgnet = df_chgnet.set_index("material_id")

df_chgnet_2000 = pd.read_csv(PRED_FILES.__dict__["CHGNet"])
df_chgnet_2000 = df_chgnet_2000.set_index("material_id").add_suffix("_2000")
df_chgnet[list(df_chgnet_2000)] = df_chgnet_2000

df_chgnet_500 = pd.read_csv(PRED_FILES.__dict__["CHGNet"].replace("-06", "-04"))
df_chgnet_500 = df_chgnet_500.set_index("material_id").add_suffix("_500")
df_chgnet[list(df_chgnet_500)] = df_chgnet_500

e_form_abs_diff = "e_form_abs_diff"
df_chgnet[e_form_abs_diff] = abs(df_chgnet[e_form_2000] - df_chgnet[e_form_500])
df_plot = df_chgnet.round(3).query(f"{e_form_abs_diff} > {min_e_diff}")


plot_labels = {
e_form_500: "CHGNet E<sub>form</sub> after 500 steps",
e_form_2000: "CHGNet E<sub>form</sub> after 2000 steps",
e_form_abs_diff: "Δ E<sub>form</sub>",
}

fig = df_plot.reset_index().plot.scatter(
x=e_form_500,
y=e_form_2000,
backend="plotly",
hover_name="material_id",
hover_data=["formula"],
labels=plot_labels,
size=e_form_abs_diff,
color=e_form_abs_diff,
template="plotly_white",
)

fig.layout.margin.update(b=20, l=40, r=20, t=50)
fig.layout.coloraxis.colorbar.update(
title=dict(text="Energy Diff (eV/atom)", side="right"), thickness=10
)
# slightly increase scatter point size (lower sizeref means larger)
fig.update_traces(marker_sizeref=0.02, selector=dict(mode="markers"))

app = hook_up_fig_with_struct_viewer(
fig,
df_plot,
"chgnet_structure",
# validate_id requires material_id to be hover_name
validate_id=lambda id: id.startswith(("wbm-", "mp-", "mvc-")),
)
app.run_server(debug=True, port=8000)
2 changes: 1 addition & 1 deletion scripts/difficult_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# %%
n_rows, n_cols = 5, 4
for which in ("best", "worst"):
fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_rows, 4 * n_cols))
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
n_axs = len(axs.flat)

errs = (
Expand Down
5 changes: 3 additions & 2 deletions scripts/rolling_mae_vs_hull_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@


# %%
# model = "Wrenformer"
model = "Wrenformer"
model = "M3GNet + MEGNet"
model = "MEGNet"
model = "MEGNet Old"
model = "CHGNet"

ax, df_err, df_std = rolling_mae_vs_hull_dist(
e_above_hull_true=df_wbm[each_true_col],
e_above_hull_errors={model: df_wbm[e_form_col] - df_wbm[model]},
Expand Down
3 changes: 2 additions & 1 deletion site/src/app.css
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,9 @@ input {
font-size: 12pt;
background: rgba(255, 255, 255, 0.1);
color: var(--text-color);
transition: background 0.2s;
}
input:focus {
outline: none;
background: rgba(255, 255, 255, 0.15);
background: rgba(255, 255, 255, 0.2);
}
3 changes: 2 additions & 1 deletion site/src/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ export type Citation = {
title: string
subtitle?: string
authors: {
name: string
'family-names': string
'given-names': string
affiliation: string
affil_key: string
orcid: string
Expand Down
14 changes: 7 additions & 7 deletions site/src/routes/contribute/+page.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ assert list(df_wbm) == [

You can also download the data files directly from GitHub:

1. [`2022-10-19-wbm-summary.csv`]({repo}/raw/v1.0.0/data/wbm/2022-10-19-wbm-summary.csv) [[GitHub]({repo}/blob/v1/data/wbm/2022-10-19-wbm-summary.csv)]: Computed material properties only, no structures. Available properties are VASP energy, formation energy, energy above the convex hull, volume, band gap, number of sites per unit cell, and more. e_form_per_atom and e_above_hull each have 3 separate columns for old, new and no Materials
1. [`2022-10-19-wbm-init-structs.json`]({repo}/raw/v1.0.0/data/wbm/2022-10-19-wbm-init-structs.json) [[GitHub]({repo}/blob/v1/data/wbm/2022-10-19-wbm-init-structs.json)]: Unrelaxed WBM structures
1. [`2022-10-19-wbm-cses.json`]({repo}/raw/v1.0.0/data/wbm/2022-10-19-wbm-cses.json) [[GitHub]({repo}/blob/v1/data/wbm/2022-10-19-wbm-cses.json)]: Relaxed WBM structures along with final VASP energies
1. [`2023-01-10-mp-energies.json.gz`]({repo}/raw/v1.0.0/data/wbm/2023-01-10-mp-energies.json.gz) [[GitHub]({repo}/blob/v1/data/wbm/2023-01-10-mp-energies.json.gz)]: Materials Project formation energies and energies above convex hull
1. [`2023-02-07-mp-computed-structure-entries.json.gz`]({repo}/raw/v1.0.0/data/wbm/2023-02-07-mp-computed-structure-entries.json.gz) [[GitHub]({repo}/blob/v1/data/wbm/2023-02-07-mp-computed-structure-entries.json.gz)]: Materials Project computed structure entries
1. [`2023-02-07-ppd-mp.pkl.gz`]({repo}/raw/v1.0.0/data/wbm/2023-02-07-ppd-mp.pkl.gz) [[GitHub]({repo}/blob/v1/data/wbm/2023-02-07-ppd-mp.pkl.gz)]: [PatchedPhaseDiagram](https://pymatgen.org/pymatgen.analysis.phase_diagram.html#pymatgen.analysis.phase_diagram.PatchedPhaseDiagram) constructed from all MP ComputedStructureEntries
1. [`2022-09-19-mp-elemental-ref-energies.json`]({repo}/raw/v1.0.0/data/wbm/2022-09-19-mp-elemental-ref-energies.json) [[GitHub]({repo}/blob/v1/data/wbm/2022-09-19-mp-elemental-ref-energies.json)]: Minimum energy PDEntries for each element present in the Materials Project
1. [`2022-10-19-wbm-summary.csv`]({repo}/blob/-/data/wbm/2022-10-19-wbm-summary.csv): Computed material properties only, no structures. Available properties are VASP energy, formation energy, energy above the convex hull, volume, band gap, number of sites per unit cell, and more. e_form_per_atom and e_above_hull each have 3 separate columns for old, new and no Materials Project energy corrections.
1. [`2022-10-19-wbm-init-structs.json`]({repo}/blob/-/data/wbm/2022-10-19-wbm-init-structs.json): Unrelaxed WBM structures
1. [`2022-10-19-wbm-cses.json`]({repo}/blob/-/data/wbm/2022-10-19-wbm-cses.json): Relaxed WBM structures along with final VASP energies
1. [`2023-01-10-mp-energies.json.gz`]({repo}/blob/-/data/mp/2023-01-10-mp-energies.json.gz): Materials Project formation energies and energies above convex hull
1. [`2023-02-07-mp-computed-structure-entries.json.gz`]({repo}/blob/-/data/mp/2023-02-07-mp-computed-structure-entries.json.gz): Materials Project computed structure entries
1. [`2023-02-07-ppd-mp.pkl.gz`]({repo}/blob/-/data/mp/2023-02-07-ppd-mp.pkl.gz): [PatchedPhaseDiagram](https://pymatgen.org/pymatgen.analysis.phase_diagram.html#pymatgen.analysis.phase_diagram.PatchedPhaseDiagram) constructed from all MP ComputedStructureEntries
1. [`2022-09-19-mp-elemental-reference-entries.json`]({repo}/blob/-/data/mp/2022-09-19-mp-elemental-reference-entries.json): Minimum energy PDEntries for each element present in the Materials Project

[wbm paper]: https://nature.com/articles/s41524-020-00481-6

Expand Down
3 changes: 1 addition & 2 deletions site/src/routes/models/+page.server.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import type { ModelData } from '$lib'
import { compile } from 'mdsvex'
import { dirname } from 'path'
import type { PageServerLoad } from './$types'
import model_stats from './model-stats.json'

export const load: PageServerLoad = async () => {
export const load = async () => {
const yml = import.meta.glob(`$root/models/**/metadata.yml`, {
eager: true,
})
Expand Down
10 changes: 8 additions & 2 deletions site/src/routes/models/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import { RadioButtons, Tooltip } from 'svelte-zoo'
import { flip } from 'svelte/animate'
import { fade } from 'svelte/transition'
import type { PageData, Snapshot } from './$types'
import type { Snapshot } from './$types'
export let data: PageData
export let data
let sort_by: keyof ModelStats | 'model_name' = `F1`
let show_details = false
Expand Down Expand Up @@ -125,4 +125,10 @@
span :global(div.zoo-radio-btn span) {
padding: 1pt 4pt;
}
input[type='number'] {
text-align: center;
}
input[type='number']::-webkit-inner-spin-button {
display: none;
}
</style>
6 changes: 3 additions & 3 deletions site/src/routes/paper/+layout.server.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import type { Citation } from '$lib'
import fs from 'fs'
import yml from 'js-yaml'
import type { LayoutServerLoad } from './$types'

export const load: LayoutServerLoad = async ({ route }) => {
export const load = async ({ route }) => {
const data = fs.readFileSync(`src/routes/${route.id}/+page.md`, `utf8`)
const cff = fs.readFileSync(`../citation.cff`, `utf8`)

// Count the number of words using a regular expression
const word_count = data.match(/\b\w+\b/g)?.length ?? null

return { word_count, ...yml.load(cff) }
return { word_count, ...(yml.load(cff) as Citation) }
}
3 changes: 1 addition & 2 deletions site/src/routes/paper/+layout.svelte
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
<script lang="ts">
import { References } from '$lib'
import { pretty_num } from 'elementari/labels'
import type { LayoutServerData } from './$types'
import { references } from './references.yaml'
export let data: LayoutServerData
export let data
const authors = data.authors.map(
(auth) => `${auth[`given-names`]} ${auth[`family-names`]}<sup>${auth.affil_key}</sup>`
Expand Down

0 comments on commit fa1a439

Please sign in to comment.