Skip to content

Commit

Permalink
add site/src/figs/2023-02-01-mp-elemental-ref-energies.svelte generat…
Browse files Browse the repository at this point in the history
…ed in data/wbm/analysis.py

add hyperparams and notes section to models/**/metadata.yml
convert figure captions from <figcaption> element to markdown blockquote (>)
add SVG link icons to figure captions (same as headings)
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 4cf5028 commit 9f43711
Show file tree
Hide file tree
Showing 23 changed files with 276 additions and 152 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
- id: isort

- repo: https://github.com/psf/black
rev: 22.12.0
rev: 23.1.0
hooks:
- id: black

Expand Down Expand Up @@ -56,7 +56,7 @@ repos:
- id: codespell
stages: [commit, commit-msg]
exclude_types: [csv, json, svg]
exclude: ^(.+references.yaml)$
exclude: ^(.+references.yaml|site/src/figs/.+)$

- repo: https://github.com/PyCQA/autoflake
rev: v2.0.1
Expand Down
33 changes: 33 additions & 0 deletions data/wbm/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from matbench_discovery import FIGS, today
from matbench_discovery.data import df_wbm
from matbench_discovery.energy import mp_elem_reference_entries
from matbench_discovery.plots import pio

"""
Expand Down Expand Up @@ -129,3 +130,35 @@

save_fig(fig, f"{FIGS}/{today}-wbm-each-hist.svelte")
save_fig(fig, f"./figs/{today}-wbm-each-hist.svg", width=1000, height=500)


# %% tabulate the MP elemental reference entries, one row per entry
e_col, n_atoms_col = "Energy (eV/atom)", "Number of Atoms"
mp_ref_data = [
    {
        "Element": elem_symbol,
        e_col: ref_entry.energy_per_atom,
        n_atoms_col: ref_entry.composition.num_atoms,
        "Name": ref_entry.composition.elements[0].long_name,
        "Number": ref_entry.composition.elements[0].number,
    }
    for elem_symbol, ref_entry in mp_elem_reference_entries.items()
]
# rows are ordered by the "Number" column, which is used as the x-axis below
df_ref = pd.DataFrame(mp_ref_data).sort_values("Number")


# %% plot MP elemental reference energies vs atomic number
# marker size = number of atoms in reference structure
fig = df_ref.round(2).plot.scatter(
    x="Number",
    y=e_col,
    size=n_atoms_col,
    hover_data=list(df_ref),
    backend="plotly",
)
fig.update_traces(mode="markers+lines")
fig.layout.margin = dict(l=0, r=0, t=0, b=0)

# label every point with its element symbol; columns are picked by name so the
# loop does not depend on the key order of the dicts in mp_ref_data
for symbol, e_per_atom, atomic_num in zip(
    df_ref["Element"], df_ref[e_col], df_ref["Number"]
):
    fig.add_annotation(
        x=atomic_num, y=e_per_atom, text=symbol, showarrow=False, font_size=10
    )

fig.show()

save_fig(fig, f"{FIGS}/{today}-mp-elemental-ref-energies.svelte")
2 changes: 1 addition & 1 deletion data/wbm/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ The dummy MAE of always predicting the test set mean is **0.17 eV/atom**.

The number of stable materials (according to the MP convex hull which is spanned by the training data the models have access to) is **97k** out of **257k**, resulting in a dummy stability hit rate of **37%**.

> Incidentally, [according to the authors](https://www.nature.com/articles/s41524-020-00481-6#Sec2), a more accurate stability rate according to the combined MP+WBM convex hull of the first 3 rounds of elemental substitution is 18,479 out of 189,981 crystals ($\approx$ 9.7%).
> Note: [According to the authors](https://www.nature.com/articles/s41524-020-00481-6#Sec2), the stability rate w.r.t. the more complete hull constructed from the combined train and test set (MP + WBM) for the first 3 rounds of elemental substitution is 18,479 out of 189,981 crystals ($\approx$ 9.7%).
<slot name="wbm-each-hist">
<img src="./figs/2023-01-26-wbm-each-hist.svg" alt="WBM energy above MP convex hull distribution">
Expand Down
4 changes: 2 additions & 2 deletions models/bowsr/join_bowsr_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@


# %%
out_path = f"{ROOT}/models/bowsr/{today}-bowsr-megnet-wbm-{task_type}.json.gz"
out_path = f"{module_dir}/{today}-bowsr-megnet-wbm-{task_type}.json.gz"
df_bowsr.reset_index().to_json(out_path, default_handler=lambda x: x.as_dict())

# save energy and formation energy as CSV for fast loading
df_bowsr.select_dtypes("number").to_csv(out_path.replace(".json.gz", ".csv"))

in_path = f"{ROOT}/models/bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.json.gz"
# in_path = f"{module_dir}/2023-01-23-bowsr-megnet-wbm-IS2RE.json.gz"
# df_bowsr = pd.read_json(in_path).set_index("material_id")
14 changes: 8 additions & 6 deletions models/bowsr/metadata.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ requirements:
pandas: 1.5.1
trained_on_benchmark: false

# model specific keys record hyperparameter choices
optimize_kwargs:
alpha: 0.000676
n_init: 100
n_iter: 100
task_type: IS2RE
hyperparams:
Optimizer Params:
alpha: 0.000676
n_init: 100
n_iter: 100

notes:
training: Uses same version of MEGNet as standalone MEGNet.
3 changes: 3 additions & 0 deletions models/cgcnn/metadata.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@ requirements:
numpy: 1.24.0
pandas: 1.5.1
trained_on_benchmark: true

hyperparams:
Ensemble Size: 10
7 changes: 4 additions & 3 deletions models/m3gnet/join_m3gnet_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pymatgen.analysis.phase_diagram import PDEntry
from tqdm import tqdm

from matbench_discovery import ROOT, today
from matbench_discovery import today
from matbench_discovery.data import as_dict_handler
from matbench_discovery.energy import get_e_form_per_atom

Expand Down Expand Up @@ -52,10 +52,11 @@


# %%
out_path = f"{ROOT}/models/m3gnet/{today}-m3gnet-wbm-{task_type}.json.gz"
out_path = f"{module_dir}/{today}-m3gnet-wbm-{task_type}.json.gz"
df_m3gnet = df_m3gnet.round(5)
df_m3gnet.reset_index().to_json(out_path, default_handler=as_dict_handler)

df_m3gnet.select_dtypes("number").to_csv(out_path.replace(".json.gz", ".csv"))

# in_path = f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
# in_path = f"{module_dir}/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
# df_m3gnet = pd.read_json(in_path).set_index("material_id")
2 changes: 2 additions & 0 deletions models/m3gnet/metadata.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ requirements:
numpy: 1.24.0
pandas: 1.5.1
trained_on_benchmark: false
notes:
training: Using pre-trained model released with paper. Was only trained on a subset of 62,783 MP relaxation trajectories in the 2018 database release (see [related issue](https://github.com/materialsvirtuallab/m3gnet/issues/20#issuecomment-1207087219)).
6 changes: 4 additions & 2 deletions models/m3gnet/wbm_pre_vs_post_m3gnet_relaxation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# %%
import os

import pandas as pd
import plotly.express as px
from pymatgen.core import Structure
Expand All @@ -12,7 +14,7 @@
__author__ = "Janosh Riebesell"
__date__ = "2022-06-18"


module_dir = os.path.dirname(__file__)
del plots # https://github.com/PyCQA/pyflakes/issues/366


Expand Down Expand Up @@ -220,5 +222,5 @@
# %% write df back to compressed JSON
# filter out columns containing 'rs2re'
# df_m3gnet_is2re.reset_index().filter(regex="^((?!rs2re).)*$").to_json(
# f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE-2.json.gz"
# f"{module_dir}/2022-10-31-m3gnet-wbm-IS2RE-2.json.gz"
# ).set_index("material_id")
2 changes: 2 additions & 0 deletions models/megnet/metadata.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@ requirements:
numpy: 1.24.0
pandas: 1.5.1
trained_on_benchmark: false
notes:
training: Using pre-trained model released with paper. Was only trained on `MP-crystals-2018.6.1` dataset [available on Figshare](https://figshare.com/articles/Graphs_of_materials_project/7451351).
3 changes: 3 additions & 0 deletions models/wrenformer/metadata.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ requirements:
numpy: 1.24.0
pandas: 1.5.1
trained_on_benchmark: true

hyperparams:
Ensemble Size: 10
2 changes: 0 additions & 2 deletions scripts/cumulative_clf_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@
# title=title,
legend=dict(yanchor="bottom", y=0.02, xanchor="right", x=0.9),
)
fig.update_xaxes(matches=None, showticklabels=True, title="")
fig.update_yaxes(matches=None, showticklabels=True, title="")
fig.layout.height = 500
fig.add_annotation(
x=0.5,
Expand Down
21 changes: 17 additions & 4 deletions scripts/prc_roc_curves_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pandas as pd
from pymatviz.utils import save_fig
from tqdm import tqdm

from matbench_discovery import FIGS, today
from matbench_discovery.data import load_df_wbm_preds
Expand Down Expand Up @@ -37,12 +38,12 @@
# %%
df_roc = pd.DataFrame()

for model in models:
for model in (pbar := tqdm(models)):
pbar.set_description(model)
df_wbm[f"{model}_{each_pred_col}"] = df_wbm[each_true_col] + (
df_wbm[model] - df_wbm[e_form_col]
)
for stab_treshold in np.arange(-0.4, 0.4, 0.01):

metrics = stable_metrics(
df_wbm[each_true_col], df_wbm[f"{model}_{each_pred_col}"], stab_treshold
)
Expand Down Expand Up @@ -74,11 +75,20 @@
fig.layout.coloraxis.colorbar.update(
x=1, y=1, xanchor="right", yanchor="top", thickness=14, len=0.27, title_side="right"
)
fig.add_annotation(text="No skill", x=0.5, y=0.5, showarrow=False, yshift=-10)
fig.add_shape(type="line", x0=0, y0=0, x1=1, y1=1, line=line, row="all", col="all")
fig.add_annotation(
text="No skill", x=0.5, y=0.5, showarrow=False, yshift=-10, textangle=-30
)
# allow scrolling and zooming each subplot individually
fig.update_xaxes(matches=None)
fig.update_yaxes(matches=None)
fig.show()


# %%
save_fig(fig, f"{FIGS}/{today}-roc-models.svelte")


# %%
fig = df_roc.plot.scatter(
x="Recall",
Expand All @@ -102,8 +112,11 @@
fig.add_annotation(
text="No skill", x=0, y=0.5, showarrow=False, xanchor="left", xshift=10, yshift=10
)
# allow scrolling and zooming each subplot individually
fig.update_xaxes(matches=None)
fig.update_yaxes(matches=None)
fig.show()


# %%
save_fig(fig, f"{FIGS}/{today}-roc-models.svelte")
save_fig(fig, f"{FIGS}/{today}-prc-models.svelte")
38 changes: 16 additions & 22 deletions site/src/app.css
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ tbody tr:nth-child(odd) {
h1 {
text-align: center;
}
blockquote {
border-left: 3pt solid var(--blue);
padding: 4pt 4pt 4pt 9pt;
margin: 1em auto;
background-color: rgba(255, 255, 255, 0.05);
}
blockquote > :is(:first-child, :last-child) {
margin: 0 auto;
}

:where(h2, h3, h4, h5, h6) {
scroll-margin-top: 50px;
Expand All @@ -114,26 +123,19 @@ h1 {
color: orange;
}
/* style heading anchors added by rehype-autolink-headings, see svelte.config.js */
:where(h2, h3, h4, h5, h6) a[aria-hidden='true'] {
:where(h2, h3, h4, h5, h6, strong) a[aria-hidden='true'] {
transition: 0.3s;
margin-left: 4pt;
opacity: 0;
}
:where(h2, h3, h4, h5, h6):hover a[aria-hidden='true'] {
:where(h2, h3, h4, h5, h6, strong):hover a[aria-hidden='true'] {
opacity: 1;
}

blockquote {
border-left: 3pt solid var(--blue);
padding: 4pt 2pt 4pt 9pt;
margin: 1em auto;
background-color: rgba(255, 255, 255, 0.1);
}
blockquote p:last-child {
margin-bottom: 0;
}
blockquote p:first-child {
margin-top: 0;
strong a[aria-hidden='true'] {
vertical-align: middle;
line-height: 1em;
position: absolute;
left: 0;
}

/* for /api/[slug] */
Expand Down Expand Up @@ -177,11 +179,3 @@ sup {
/* https://stackoverflow.com/a/6594576 */
line-height: 0;
}

figure {
display: block;
}
figcaption {
text-align: center;
margin: 1ex 1em;
}
1 change: 1 addition & 0 deletions site/src/figs/2023-02-02-mp-elemental-ref-energies.svelte

Large diffs are not rendered by default.

Loading

0 comments on commit 9f43711

Please sign in to comment.