add scatter plot of largest errors averaged over models vs DFT hull distance

improve rolling-mae-vs-hull-dist-wbm-batches-models caption

update deps

move metric table rows showing MEGNet combos with M3GNet and CHGNet to SI

add prop hide: string[] to metrics table to hide rows with matching headers
janosh committed Jun 20, 2023
1 parent 8a0c3bb commit 457abf0
Showing 25 changed files with 336 additions and 281 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]
 
 repos:
   - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: v0.0.260
+    rev: v0.0.261
     hooks:
       - id: ruff
         args: [--fix]
@@ -57,7 +57,7 @@ repos:
           - prettier
           - prettier-plugin-svelte
           - svelte
-        exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yml|yaml|json))$
+        exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yaml|json))$
 
   - repo: https://github.com/pre-commit/mirrors-eslint
     rev: v8.37.0
2 changes: 1 addition & 1 deletion data/wbm/eda.py
@@ -181,7 +181,7 @@
 # weight in each WBM composition. TLDR: no obvious structure in the data
 # was hoping to find certain clusters to have higher or lower errors after seeing
 # many models struggle on the halogens in per-element error periodic table heatmaps
-# https://matbench-discovery.janosh.dev/models
+# https://janosh.github.io/matbench-discovery/models
 df_2d_tsne = pd.read_csv(f"{module_dir}/tsne/one-hot-112-composition-2d.csv.gz")
 df_2d_tsne = df_2d_tsne.set_index("material_id")
 
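For readers unfamiliar with the featurization referenced above: the file name suggests 2D t-SNE coordinates computed from one-hot (element-fraction) composition vectors over 112 elements. A minimal sketch of such an encoding — the element list and the exact scheme are assumptions, not taken from this commit:

```python
import numpy as np
from pymatgen.core import Composition

# assumed: a fixed list of element symbols, 112 long in the real featurization
elements = ["H", "He", "Li", "Be", "B", "C", "N", "O"]  # truncated for brevity

def encode_composition(formula: str) -> np.ndarray:
    """One slot per element, filled with that element's atomic fraction."""
    vec = np.zeros(len(elements))
    for elem, frac in Composition(formula).fractional_composition.items():
        vec[elements.index(str(elem))] = frac
    return vec

encode_composition("Li2O")  # ~0.667 in Li's slot, ~0.333 in O's slot, zeros elsewhere
```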
17 changes: 11 additions & 6 deletions matbench_discovery/preds.py
@@ -9,7 +9,7 @@
 from matbench_discovery import ROOT
 from matbench_discovery.data import Files, glob_to_df
 from matbench_discovery.metrics import stable_metrics
-from matbench_discovery.plots import model_labels
+from matbench_discovery.plots import eVpa, model_labels, quantity_labels
 
 """Centralize data-loading and computing metrics for plotting scripts"""
 
@@ -19,7 +19,12 @@
 e_form_col = "e_form_per_atom_mp2020_corrected"
 each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
 each_pred_col = "e_above_hull_pred"
-model_mean_err_col = "Mean over models"
+model_mean_err_col = "Mean error all models"
+model_std_col = "Std. dev. over models"
+
+
+quantity_labels[model_mean_err_col] = f"{model_mean_err_col} {eVpa}"
+quantity_labels[model_std_col] = f"{model_std_col} {eVpa}"
 
 
 class PredFiles(Files):
@@ -34,8 +39,6 @@ class PredFiles(Files):
     bowsr_megnet = "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv"
     # default CHGNet model from publication with 400,438 params
     chgnet = "chgnet/2023-03-06-chgnet-wbm-IS2RE.csv"
-    # CHGNet-relaxed structures fed into MEGNet for formation energy prediction
-    # chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv"
 
     # CGCNN 10-member ensemble
     cgcnn = "cgcnn/2023-01-26-test-cgcnn-wbm-IS2RE/cgcnn-ensemble-preds.csv"
@@ -44,11 +47,13 @@ class PredFiles(Files):
 
     # original M3GNet straight from publication, not re-trained
     m3gnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
-    # M3GNet-relaxed structures fed into MEGNet for formation energy prediction
-    # m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
 
     # original MEGNet straight from publication, not re-trained
     megnet = "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"
+    # CHGNet-relaxed structures fed into MEGNet for formation energy prediction
+    # chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv"
+    # M3GNet-relaxed structures fed into MEGNet for formation energy prediction
+    # m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
 
     # Magpie composition+Voronoi tessellation structure features + sklearn random forest
     voronoi_rf = "voronoi/2022-11-27-train-test/e-form-preds-IS2RE.csv"
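A quick sketch of how the `quantity_labels` mapping extended above is typically consumed: column names stay machine-readable while plots get unit-annotated axis titles. The DataFrame and values below are hypothetical:

```python
import pandas as pd
import plotly.express as px

eVpa = "(eV/atom)"  # unit suffix, matching the f-strings in the diff
quantity_labels = {"Mean error all models": f"Mean error all models {eVpa}"}

df = pd.DataFrame({"Mean error all models": [0.10, 0.30, 0.05]})
# plotly express swaps in pretty labels without renaming the underlying columns
fig = px.scatter(df, y="Mean error all models", labels=quantity_labels)
```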
4 changes: 2 additions & 2 deletions readme.md
@@ -17,12 +17,12 @@ Matbench Discovery
 Matbench Discovery is an [interactive leaderboard](https://janosh.github.io/matbench-discovery) and associated [PyPI package](https://pypi.org/project/matbench-discovery) which together make it easy to benchmark ML energy models on a task designed to closely simulate a high-throughput discovery campaign for new stable inorganic crystals.
 
-In version 1 of this benchmark, we explore 8 models covering multiple methodologies ranging from random forests to graph neural networks, from one-shot predictors to iterative Bayesian optimizers and interatomic potential-based relaxers. We find [M3GNet](https://github.com/materialsvirtuallab/m3gnet) ([paper](https://doi.org/10.1038/s43588-022-00349-3)) to achieve the highest F1 score of 0.58 and $R^2$ of 0.59 while [MEGNet](https://github.com/materialsvirtuallab/megnet) ([paper](https://doi.org/10.1021/acs.chemmater.9b01294)) wins on discovery acceleration factor (DAF) with 2.94. See the [**full results**](https://matbench-discovery.janosh.dev/paper#results) in our interactive dashboard which provides valuable insights for maintainers of large-scale materials databases. We show these models have become powerful enough to warrant deploying them as triaging steps to more effectively allocate compute in high-throughput DFT relaxations.
+In version 1 of this benchmark, we explore 8 models covering multiple methodologies ranging from random forests to graph neural networks, from one-shot predictors to iterative Bayesian optimizers and interatomic potential-based relaxers. We find [CHGNet](https://github.com/CederGroupHub/chgnet) ([paper](https://doi.org/10.48550/arXiv.2302.14231)) to achieve the highest F1 score of 0.59, $R^2$ of 0.61 and a discovery acceleration factor (DAF) of 3.06 (meaning a 3x higher rate of stable structures compared to dummy selection in our already enriched search space). See the [**full results**](https://janosh.github.io/matbench-discovery/paper#results) in our interactive dashboard which provides valuable insights for maintainers of large-scale materials databases. We show these models have become powerful enough to warrant deploying them as triaging steps to more effectively allocate compute in high-throughput DFT relaxations.
 
 <slot name="metrics-table" />
 
 We welcome contributions that add new models to the leaderboard through [GitHub PRs](https://github.com/janosh/matbench-discovery/pulls). See the [usage and contributing guide](https://janosh.github.io/matbench-discovery/contribute) for details.
 
 For a version 2 release of this benchmark, we plan to merge the current training and test sets into the new training set and acquire a much larger test set (potentially at meta-GGA level of theory) compared to the v1 test set of 257k structures. Anyone interested in joining this effort please [open a GitHub discussion](https://github.com/janosh/matbench-discovery/discussions) or [reach out privately](mailto:[email protected]?subject=Matbench%20Discovery).
 
-For detailed results and analysis, check out the [paper](https://matbench-discovery.janosh.dev/paper) and [supplementary material](https://matbench-discovery.janosh.dev/si).
+For detailed results and analysis, check out the [paper](https://janosh.github.io/matbench-discovery/paper) and [supplementary material](https://janosh.github.io/matbench-discovery/si).
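Since the readme leans on the discovery acceleration factor, here is a minimal sketch of the standard definition (precision divided by test-set prevalence of stable structures). All numbers are hypothetical, chosen only to illustrate the arithmetic:

```python
n_test = 256_963    # test-set size (WBM v1 is ~257k structures)
n_stable = 42_000   # hypothetical count of truly stable structures
prevalence = n_stable / n_test  # hit rate of dummy (random) selection

precision = 0.50    # hypothetical: half the model's stable predictions are correct
daf = precision / prevalence    # DAF > 1 means the model beats random selection
```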
4 changes: 1 addition & 3 deletions scripts/analyze_element_errors.py
@@ -181,10 +181,8 @@
 # %% plot EACH errors against least prevalent element in structure (by occurrence in
 # MP training set). this seems to correlate more with model error
 n_examp_for_rarest_elem_col = "Examples for rarest element in structure"
-df_wbm["composition"] = df_wbm.get("composition", df_wbm.formula.map(Composition))
-df_elem_err.loc[list(map(str, df_wbm.composition[0]))][train_count_col].min()
 df_wbm[n_examp_for_rarest_elem_col] = [
-    df_elem_err.loc[list(map(str, Composition(formula)))][train_count_col].min()
+    df_elem_err[train_count_col].loc[list(map(str, Composition(formula)))].min()
     for formula in tqdm(df_wbm.formula)
 ]
 
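The one-line change above replaces `df.loc[rows][col]` chained indexing with selecting the count column first, then the element rows. A self-contained toy version of the lookup (counts are made up):

```python
import pandas as pd
from pymatgen.core import Composition

# toy training-set element occurrence counts (hypothetical numbers)
df_elem_err = pd.DataFrame({"MP Occurrences": {"Fe": 10_000, "O": 25_000, "Tc": 40}})

elems = list(map(str, Composition("FeTcO4")))  # ['Fe', 'Tc', 'O'] (order may vary)
# rarest-element count: min occurrence among the formula's elements -> 40 (Tc)
rarest = df_elem_err["MP Occurrences"].loc[elems].min()
```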
83 changes: 59 additions & 24 deletions scripts/analyze_model_failure_cases.py
@@ -26,11 +26,14 @@
     df_preds,
     each_true_col,
     model_mean_err_col,
+    model_std_col,
 )
 
 __author__ = "Janosh Riebesell"
 __date__ = "2023-02-15"
 
+models = list(df_each_pred)
+df_preds[model_std_col] = df_preds[models].std(axis=1)
 df_each_err[model_mean_err_col] = df_preds[model_mean_err_col] = df_each_err.abs().mean(
     axis=1
 )
@@ -158,29 +161,53 @@
 # save_fig(fig, f"{FIGS}/scatter-largest-each-errors-fp-diff-models.svelte")
 
 
-# %% plotly scatter plot of largest model errors with points sized by mean error and
-# colored by true stability.
-# while some points lie on a horizontal line of constant error, more follow the identity
-# line suggesting the models failed to learn the true physics in these materials
-fig = df_preds.nlargest(200, model_mean_err_col).plot.scatter(
-    x=each_true_col,
-    y=model_mean_err_col,
-    color=each_true_col,
-    size=model_mean_err_col,
-    backend="plotly",
-)
-fig.layout.coloraxis.colorbar.update(
-    title="DFT distance to convex hull (eV/atom)",
-    title_side="top",
-    yanchor="bottom",
-    y=1,
-    xanchor="center",
-    x=0.5,
-    orientation="h",
-    thickness=12,
-)
+# %%
+df_mp = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index("material_id")
+train_count_col = "MP Occurrences"
+df_elem_counts = count_elements(df_mp.formula_pretty, count_mode="occurrence").to_frame(
+    name=train_count_col
+)
+
+n_examp_for_rarest_elem_col = "Examples for rarest element in structure"
+df_wbm[n_examp_for_rarest_elem_col] = [
+    df_elem_counts[train_count_col].loc[list(map(str, Composition(formula)))].min()
+    for formula in tqdm(df_wbm.formula)
+]
+df_preds[n_examp_for_rarest_elem_col] = df_wbm[n_examp_for_rarest_elem_col]
+
+
+# %% scatter plot of largest model errors vs. DFT hull distance
+# while some points lie on a horizontal line of constant error, more follow the identity
+# line, showing models are biased toward predicting low energies, likely a result of
+# training on MP, which is strongly enriched in low-energy structures.
+# it's also possible the models failed to learn whatever physics makes these materials
+# highly unstable
+fig = (
+    df_preds.nlargest(200, model_mean_err_col)
+    .round(2)
+    .plot.scatter(
+        x=each_true_col,
+        y=model_mean_err_col,
+        color=model_std_col,
+        size=n_examp_for_rarest_elem_col,
+        backend="plotly",
+        hover_name="material_id",
+        hover_data=["formula"],
+        color_continuous_scale="Turbo",
+    )
+)
+# yanchor="bottom", y=1, xanchor="center", x=0.5, orientation="h", thickness=12
+fig.layout.coloraxis.colorbar.update(title_side="right", thickness=14)
 add_identity_line(fig)
+fig.layout.title = (
+    "Largest model errors vs. DFT hull distance colored by model disagreement"
+)
+# tried setting error_y=model_std_col but looks bad
+# fig.update_traces(error_y=dict(color="rgba(255,255,255,0.2)", width=3, thickness=2))
 fig.show()
 # save_fig(fig, f"{FIGS}/scatter-largest-errors-models-mean-vs-each-true.svelte")
+# save_fig(
+#     fig, f"{ROOT}/tmp/figures/scatter-largest-errors-models-mean-vs-each-true.pdf"
+# )
 
 
 # %% find materials that were misclassified by all models
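A short worked version of the identity-line reasoning in the comment above (made-up numbers): if a model predicts a truly unstable material to be nearly stable, its error almost equals the DFT hull distance, so the point lands on the identity line.

```python
e_true = 0.80  # DFT distance above hull (eV/atom): clearly unstable
e_pred = 0.02  # model predicts near-stability, i.e. a low energy
error = abs(e_pred - e_true)  # 0.78 ≈ e_true -> point falls on the identity line
```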
@@ -203,16 +230,24 @@
 
 
 # %%
+normalized = True
 elem_counts: dict[str, pd.Series] = {}
 for col in ("All models false neg", "All models false pos"):
     elem_counts[col] = elem_counts.get(
         col, count_elements(df_preds[df_preds[col]].formula)
     )
-    fig = ptable_heatmap_plotly(elem_counts[col], font_size=10)
-    fig.layout.title = col
-    fig.layout.margin.update(l=0, r=0, t=50, b=0)
+    fig = ptable_heatmap_plotly(
+        elem_counts[col] / df_elem_counts[train_count_col]
+        if normalized
+        else elem_counts[col],
+        color_bar=dict(title=col),
+        precision=".3f",
+        cscale_range=[0, 0.1],
+    )
     fig.show()
 
+# TODO plot these for each model individually
+
 
 # %% map abs EACH model errors onto elements in structure weighted by composition
 # fraction and average over all test set structures
@@ -234,8 +269,8 @@
 # df_frac_comp = df_frac_comp.dropna(axis=1, thresh=100)  # remove Xe with only 1 entry
 
 
-# %% TODO investigate if structures with largest mean over models error can be
-# attributed to DFT gone wrong. would be cool if models can be run across large
+# %% TODO investigate if structures with the largest mean error across all models can
+# be attributed to DFT gone wrong. would be cool if models could be run across large
 # databases as correctness checkers
 df_each_err.abs().mean().sort_values()
 df_each_err.abs().mean(axis=1).nlargest(25)
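The `model_std_col` machinery added in this file amounts to using ensemble disagreement as a cheap uncertainty proxy. A self-contained sketch with toy predictions (model names from the benchmark, values invented):

```python
import pandas as pd

# per-model predicted hull distances (eV/atom) for three hypothetical materials
df_each_pred = pd.DataFrame(
    {
        "CHGNet": [0.01, -0.20, 0.90],
        "M3GNet": [0.05, -0.10, 0.10],
        "MEGNet": [0.12, 0.30, 0.95],
    },
    index=["wbm-1-1", "wbm-1-2", "wbm-1-3"],
)
models = list(df_each_pred)  # iterating a DataFrame yields its column names
model_std = df_each_pred[models].std(axis=1)  # high where models disagree most
```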
16 changes: 10 additions & 6 deletions scripts/compile_metrics.py
@@ -111,13 +111,15 @@
 
 # %%
 time_cols = list(df_stats.filter(like=time_col))
-for col in time_cols:  # uncomment to include run times in metrics table
-    df_metrics.loc[col] = df_stats[col]
+# for col in time_cols:  # uncomment to include run times in metrics table
+#     df_metrics.loc[col] = df_stats[col]
 higher_is_better = {"DAF", "R²", "Precision", "F1", "Accuracy", "TPR", "TNR"}
-lower_is_better = {"MAE", "RMSE", "FNR", "FPR", *time_cols}
+lower_is_better = {"MAE", "RMSE", "FNR", "FPR"}
+df_metrics = df_metrics.rename(index={"R2": "R²"})
+idx_set = set(df_metrics.index)
 
 styler = (
-    df_metrics.T.rename(columns={"R2": "R²"})
+    df_metrics.T
     # append arrow up/down to table headers to indicate higher/lower metric is better
     # .rename(columns=lambda x: x + " ↑" if x in higher_is_better else x + " ↓")
     .style.format(precision=2)
@@ -141,10 +143,12 @@
 styler.hide(["Recall", "FPR", "FNR"], axis=1)
 
 
-# %% export model metrics as styled HTML table
+# %% export model metrics as styled HTML table and Svelte component
 styler.to_html(f"{ROOT}/tmp/figures/model-metrics.html")
 
+# insert svelte {...props} forwarding to the table element
 insert = """
-<script>
+<script lang="ts">
   import { sortable } from 'svelte-zoo/actions'
 </script>
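The commented-out rename in the first hunk appends ↑/↓ arrows to metric headers so readers know which direction is better. A standalone sketch of that pattern (toy metrics table):

```python
import pandas as pd

higher_is_better = {"DAF", "R²", "Precision", "F1", "Accuracy", "TPR", "TNR"}
lower_is_better = {"MAE", "RMSE", "FNR", "FPR"}

df_metrics = pd.DataFrame({"F1": [0.59], "MAE": [0.08]}, index=["CHGNet"])
styler = df_metrics.rename(
    columns=lambda col: col
    + (" ↑" if col in higher_is_better else " ↓" if col in lower_is_better else "")
).style.format(precision=2)
html = styler.to_html()  # same kind of Styler that feeds the HTML/Svelte export
```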
2 changes: 1 addition & 1 deletion scripts/cumulative_clf_metrics.py
@@ -35,7 +35,7 @@
 # fig.suptitle(title)
 fig.text(0.5, -0.08, x_label, ha="center", fontdict={"size": 16})
 if backend == "plotly":
-    fig.layout.legend.update(x=0.01, y=0.01, bgcolor="rgba(0,0,0,0)")
+    fig.layout.legend.update(x=0, y=0, bgcolor="rgba(0,0,0,0)")
     fig.layout.margin.update(l=0, r=5, t=30, b=50)
     fig.add_annotation(
         x=0.5,
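For context on the legend tweak: plotly legend positions are normalized paper coordinates, so `x=0, y=0` pins the legend to the bottom-left corner, and the transparent background keeps it from hiding traces. A standalone sketch:

```python
import plotly.graph_objects as go

fig = go.Figure(
    [go.Scatter(y=[1, 3, 2], name="model A"), go.Scatter(y=[2, 1, 3], name="model B")]
)
fig.layout.legend.update(x=0, y=0, bgcolor="rgba(0,0,0,0)")  # bottom-left, see-through
fig.layout.margin.update(l=0, r=5, t=30, b=50)
fig.show()
```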
7 changes: 3 additions & 4 deletions scripts/rolling_mae_vs_hull_dist_wbm_batches.py
@@ -6,7 +6,7 @@
 # %%
 from pymatviz.utils import save_fig
 
-from matbench_discovery import FIGS, today
+from matbench_discovery import FIGS, ROOT, today
 from matbench_discovery.plots import plt, rolling_mae_vs_hull_dist
 from matbench_discovery.preds import df_each_pred, df_preds, e_form_col, each_true_col
 
@@ -16,10 +16,10 @@
 batch_col = "batch_idx"
 df_each_pred[batch_col] = "Batch " + df_each_pred.index.str.split("-").str[1]
 df_err, df_std = None, None  # variables to cache rolling MAE and std
+model = "MEGNet"
 
 
 # %% matplotlib
-model = "Wrenformer"
 fig, ax = plt.subplots(1, figsize=(10, 9))
 markers = ("o", "v", "^", "H", "D")
 assert len(markers) == 5  # number of iterations of element substitution in WBM data set
@@ -54,7 +54,6 @@
 
 
 # %% plotly
-model = "CHGNet"
 df_pivot = df_each_pred.pivot(columns=batch_col, values=model)
 
 # unstack two-level column index into new model column
@@ -81,4 +80,4 @@
 file_model = model.lower().replace(" + ", "-").replace(" ", "-")
 img_path = f"{file_model}-rolling-mae-vs-hull-dist-wbm-batches"
 save_fig(fig, f"{FIGS}/{img_path}.svelte")
-# save_fig(f"{ROOT}/tmp/figures/{img_path}.pdf")
+save_fig(fig, f"{ROOT}/tmp/figures/{img_path}.pdf")
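The plotly branch above hinges on `DataFrame.pivot` splitting one model's predictions into one column per WBM batch. A toy version (index and values hypothetical, batch labels following the `"Batch " + ...` pattern from the diff):

```python
import pandas as pd

df_each_pred = pd.DataFrame(
    {"MEGNet": [0.10, 0.25, 0.30], "batch_idx": ["Batch 1", "Batch 1", "Batch 2"]},
    index=["wbm-1-1", "wbm-1-2", "wbm-2-1"],
)
# one column per batch; NaN where a material belongs to a different batch
df_pivot = df_each_pred.pivot(columns="batch_idx", values="MEGNet")
```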
16 changes: 8 additions & 8 deletions site/package.json
@@ -19,10 +19,10 @@
     "@iconify/svelte": "^3.1.0",
     "@rollup/plugin-yaml": "^4.0.1",
     "@sveltejs/adapter-static": "^2.0.1",
-    "@sveltejs/kit": "^1.14.0",
-    "@sveltejs/vite-plugin-svelte": "^2.0.3",
-    "@typescript-eslint/eslint-plugin": "^5.57.0",
-    "@typescript-eslint/parser": "^5.57.0",
+    "@sveltejs/kit": "^1.15.1",
+    "@sveltejs/vite-plugin-svelte": "^2.0.4",
+    "@typescript-eslint/eslint-plugin": "^5.57.1",
+    "@typescript-eslint/parser": "^5.57.1",
     "elementari": "^0.1.5",
     "eslint": "^8.37.0",
     "eslint-plugin-svelte3": "^4.0.0",
@@ -36,15 +36,15 @@
     "rehype-katex-svelte": "^1.1.2",
     "rehype-slug": "^5.1.0",
     "remark-math": "3.0.0",
-    "svelte": "^3.57.0",
-    "svelte-check": "^3.1.4",
+    "svelte": "^3.58.0",
+    "svelte-check": "^3.2.0",
     "svelte-multiselect": "^8.6.0",
     "svelte-preprocess": "^5.0.3",
     "svelte-toc": "^0.5.4",
     "svelte-zoo": "^0.4.3",
-    "svelte2tsx": "^0.6.10",
+    "svelte2tsx": "^0.6.11",
     "tslib": "^2.5.0",
-    "typescript": "5.0.2",
+    "typescript": "5.0.3",
     "vite": "^4.2.1"
   },
   "prettier": {
2 changes: 1 addition & 1 deletion site/src/app.html
@@ -31,7 +31,7 @@
     ></script>
     <link rel="stylesheet" href="/prism-vsc-dark-plus.css" />
     <!-- interactive plots -->
-    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
+    <script src="https://cdn.plot.ly/plotly-2.20.0.min.js"></script>
     <!-- math display -->
     <link
       rel="stylesheet"

Large diffs are not rendered by default.

