
Commit

fix scripts/prc_roc_curves_models.py and roc-models.svelte fig
sth wrong in stable_metrics() maybe?
janosh committed Jun 20, 2023
1 parent 5d98946 commit a42472c
Showing 10 changed files with 119 additions and 100 deletions.
10 changes: 5 additions & 5 deletions data/mp/get_mp_energies.py
@@ -58,7 +58,7 @@
 df["wyckoff_spglib"] = [get_aflow_label_from_spglib(x) for x in tqdm(df.structure)]

 df.reset_index().to_json(
-    f"{module_dir}/{today}-mp-energies.json.gz", default_handler=as_dict_handler
+    f"{module_dir}/mp-energies.json.gz", default_handler=as_dict_handler
 )

 # df = pd.read_json(f"{module_dir}/2022-08-13-mp-energies.json.gz")
@@ -78,9 +78,9 @@
 )

 annotate_mae_r2(df.formation_energy_per_atom, df.decomposition_enthalpy)
-# result on 2023-01-10: plots match. no correlation between formation energy and decomposition
-# enthalpy. R^2 = -1.571, MAE = 1.604
-# ax.figure.savefig(f"{module_dir}/{today}-mp-decomp-enth-vs-e-form.webp", dpi=300)
+# result on 2023-01-10: plots match. no correlation between formation energy and
+# decomposition enthalpy. R^2 = -1.571, MAE = 1.604
+# ax.figure.savefig(f"{module_dir}/mp-decomp-enth-vs-e-form.webp", dpi=300)


 # %% scatter plot energy above convex hull vs decomposition enthalpy
@@ -99,4 +99,4 @@
     title=f"{n_above_line:,} / {len(df):,} = {n_above_line/len(df):.1%} "
     "MP materials with\nenergy_above_hull - decomposition_enthalpy.clip(0) > 0.1"
 )
-# ax.figure.savefig(f"{module_dir}/{today}-mp-e-above-hull-vs-decomp-enth.webp", dpi=300)
+# ax.figure.savefig(f"{module_dir}/mp-e-above-hull-vs-decomp-enth.webp", dpi=300)
7 changes: 0 additions & 7 deletions matbench_discovery/plots.py
@@ -699,13 +699,6 @@ def cumulative_precision_recall(
         facet_col="metric",
         facet_col_wrap=2,
         facet_col_spacing=0.03,
-        # pivot df in case we want to show all 3 metrics in each plot's hover tooltip
-        # requires fixing index mismatch due to df sub-sampling above
-        # customdata=dict(
-        #     df_cum.reset_index()
-        #     .pivot(index="index", columns="metric")
-        #     .items()
-        # ),
         **kwargs,
     )
5 changes: 5 additions & 0 deletions matbench_discovery/preds.py
@@ -37,3 +37,8 @@
 df_each_pred = pd.DataFrame()
 for model in df_metrics.T.MAE.sort_values().index:
     df_each_pred[model] = df_wbm[each_true_col] + df_wbm[model] - df_wbm[e_form_col]
+
+
+df_each_err = pd.DataFrame()
+for model in df_metrics.T.MAE.sort_values().index:
+    df_each_err[model] = df_wbm[model] - df_wbm[e_form_col]
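The new `df_each_err` frame differs from `df_each_pred` only by the `each_true_col` offset: `df_each_err` holds each model's formation-energy error (prediction minus DFT), while `df_each_pred` shifts the true distance to the convex hull by that same error. A minimal sketch of the relation on toy data — the column names here are hypothetical stand-ins, not the repo's actual ones:

```python
import pandas as pd

# toy stand-ins for each_true_col, e_form_col and one model column
df_wbm = pd.DataFrame(
    {
        "e_above_hull_true": [0.0, 0.12, -0.05],  # DFT distance to hull (eV/atom)
        "e_form_true": [-1.2, -0.8, -2.1],  # DFT formation energy
        "SomeModel": [-1.1, -0.9, -2.0],  # model-predicted formation energy
    }
)

# formation-energy error per material (what df_each_err stores per model)
each_err = df_wbm["SomeModel"] - df_wbm["e_form_true"]
# predicted distance to hull = true distance shifted by that same error
# (what df_each_pred stores per model)
each_pred = df_wbm["e_above_hull_true"] + each_err

print(each_err.round(2).tolist())  # [0.1, -0.1, 0.1]
print(each_pred.round(2).tolist())  # [0.1, 0.02, 0.05]
```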
2 changes: 1 addition & 1 deletion scripts/hist_classified_stable_vs_hull_dist_batches.py
@@ -86,7 +86,7 @@
 for anno in fig.layout.annotations:
     if not anno.text.startswith("batch_idx="):
         continue
-    batch_idx = int(anno.text.split("=")[-1])
+    batch_idx = int(anno.text.split("=", 1)[-1])
     len_df = sum(df_wbm[batch_col] == int(batch_idx))
     anno.text = f"Batch {batch_idx} ({len_df:,})"
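This commit applies the same `maxsplit=1` fix in several scripts. The two spellings only diverge once a value itself contains `=`: `split("=")[-1]` keeps the text after the *last* `=`, while `split("=", 1)[-1]` keeps everything after the *first*. A quick illustration (the example strings are made up):

```python
anno_text = "batch_idx=3"
print(anno_text.split("=")[-1])  # "3" - both spellings agree here
print(anno_text.split("=", 1)[-1])  # "3"

title = "Model=MEGNet=RS2RE"  # hypothetical value that itself contains "="
print(title.split("=")[-1])  # "RS2RE" - silently drops "MEGNet="
print(title.split("=", 1)[-1])  # "MEGNet=RS2RE" - keeps the full value
```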
2 changes: 1 addition & 1 deletion scripts/hist_classified_stable_vs_hull_dist_models.py
@@ -88,7 +88,7 @@
     ax.set(title=f"{model_name} · {F1=:.2f} · {FPR=:.2f} · {FNR=:.2f} · {DAF=:.2f}")
 else:
     for anno in fig.layout.annotations:
-        model_name = anno.text = anno.text.split("=").pop()
+        model_name = anno.text = anno.text.split("=", 1).pop()
         if model_name not in models or not show_metrics:
             continue
         F1, FPR, FNR, DAF = (
105 changes: 62 additions & 43 deletions scripts/prc_roc_curves_models.py
@@ -1,19 +1,12 @@
 # %%
-import numpy as np
 import pandas as pd
 from pymatviz.utils import save_fig
+from sklearn.metrics import auc, precision_recall_curve, roc_curve
 from tqdm import tqdm

 from matbench_discovery import FIGS
-from matbench_discovery.metrics import stable_metrics
 from matbench_discovery.plots import pio
-from matbench_discovery.preds import (
-    df_wbm,
-    e_form_col,
-    each_pred_col,
-    each_true_col,
-    models,
-)
+from matbench_discovery.preds import df_each_pred, df_wbm, each_true_col

 __author__ = "Janosh Riebesell"
 __date__ = "2023-01-30"
@@ -34,47 +27,49 @@
 # %%
 df_roc = pd.DataFrame()

-for model in (pbar := tqdm(models)):
-    pbar.set_description(model)
-    df_wbm[f"{model}_{each_pred_col}"] = df_wbm[each_true_col] + (
-        df_wbm[model] - df_wbm[e_form_col]
-    )
-    for stab_treshold in np.arange(-0.4, 0.4, 0.01):
-        metrics = stable_metrics(
-            df_wbm[each_true_col], df_wbm[f"{model}_{each_pred_col}"], stab_treshold
-        )
-        df_tmp = pd.DataFrame(
-            {facet_col: model, color_col: stab_treshold, **metrics}, index=[0]
-        )
-        df_roc = pd.concat([df_roc, df_tmp])
+for model in (pbar := tqdm(list(df_each_pred), desc="Calculating ROC curves")):
+    pbar.set_postfix_str(model)
+    na_mask = df_wbm[each_true_col].isna() | df_each_pred[model].isna()
+    y_true = (df_wbm[~na_mask][each_true_col] <= 0).astype(int)
+    y_pred = df_each_pred[model][~na_mask]
+    fpr, tpr, thresholds = roc_curve(y_true, y_pred, pos_label=0)
+    AUC = auc(fpr, tpr)
+    title = f"{model} · {AUC=:.2f}"
+    df_tmp = pd.DataFrame(
+        {"FPR": fpr, "TPR": tpr, color_col: thresholds, "AUC": AUC, facet_col: title}
+    ).round(3)

-df_roc = df_roc.round(3)
+    df_roc = pd.concat([df_roc, df_tmp])


 # %%
-fig = df_roc.plot.scatter(
-    x="FPR",
-    y="TPR",
-    facet_col=facet_col,
-    facet_col_wrap=2,
-    backend="plotly",
-    height=800,
-    color=color_col,
-    range_x=(0, 1),
-    range_y=(0, 1),
+fig = (
+    df_roc.iloc[:: len(df_roc) // 500 or 1]
+    .sort_values(["AUC", "FPR"], ascending=False)
+    .plot.scatter(
+        x="FPR",
+        y="TPR",
+        facet_col=facet_col,
+        facet_col_wrap=2,
+        backend="plotly",
+        height=150 * len(df_roc[facet_col].unique()),
+        color=color_col,
+        range_x=(0, 1),
+        range_y=(0, 1),
+        range_color=(-0.5, 0.5),
+        hover_name=facet_col,
+        hover_data={facet_col: False},
+    )
 )

 for anno in fig.layout.annotations:
-    anno.text = anno.text.split("=")[1]  # remove Model= from subplot titles
+    anno.text = anno.text.split("=", 1)[1]  # remove Model= from subplot titles

 fig.layout.coloraxis.colorbar.update(
-    x=1, y=1, xanchor="right", yanchor="top", thickness=14, len=0.27, title_side="right"
+    x=1, y=1, xanchor="right", yanchor="top", thickness=14, len=0.2, title_side="right"
 )
 fig.add_shape(type="line", x0=0, y0=0, x1=1, y1=1, line=line, row="all", col="all")
-fig.add_annotation(
-    text="No skill", x=0.5, y=0.5, showarrow=False, yshift=-10, textangle=-30
-)
+fig.add_annotation(text="No skill", x=0.5, y=0.5, showarrow=False, yshift=-10)
 # allow scrolling and zooming each subplot individually
 fig.update_xaxes(matches=None)
 fig.update_yaxes(matches=None)
@@ -86,20 +81,44 @@


 # %%
-fig = df_roc.plot.scatter(
+df_prc = pd.DataFrame()
+
+for model in (pbar := tqdm(list(df_each_pred), desc="Calculating ROC curves")):
+    pbar.set_postfix_str(model)
+    na_mask = df_wbm[each_true_col].isna() | df_each_pred[model].isna()
+    y_true = (df_wbm[~na_mask][each_true_col] <= 0).astype(int)
+    y_pred = df_each_pred[model][~na_mask]
+    prec, recall, thresholds = precision_recall_curve(y_true, y_pred, pos_label=0)
+    df_tmp = pd.DataFrame(
+        {
+            "Precision": prec[:-1],
+            "Recall": recall[:-1],
+            color_col: thresholds,
+            facet_col: model,
+        }
+    ).round(3)

+    df_prc = pd.concat([df_prc, df_tmp])
+
+
+# %%
+fig = df_prc.iloc[:: len(df_roc) // 500 or 1].plot.scatter(
     x="Recall",
     y="Precision",
     facet_col=facet_col,
     facet_col_wrap=2,
     backend="plotly",
-    height=800,
+    height=150 * len(df_roc[facet_col].unique()),
     color=color_col,
     range_x=(0, 1),
-    range_y=(0, 1),
+    range_y=(0.5, 1),
+    range_color=(-0.5, 1),
+    hover_name=facet_col,
+    hover_data={facet_col: False},
 )

 for anno in fig.layout.annotations:
-    anno.text = anno.text.split("=")[1]  # remove Model= from subplot titles
+    anno.text = anno.text.split("=", 1)[1]  # remove Model= from subplot titles

 fig.layout.coloraxis.colorbar.update(
     x=0.5, y=1.1, thickness=14, len=0.4, orientation="h"
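For orientation, here is a self-contained sketch (synthetic data, not the repo's) of the ROC recipe the rewritten loop uses: binarize stability as E<sub>above hull</sub> ≤ 0, use the model's predicted distance to the hull as the ranking score, and pass `pos_label=0` since a *higher* predicted energy is a stronger vote for the *unstable* class. Note also the `iloc[:: len(df_roc) // 500 or 1]` idiom in the plotting calls above: it thins each curve to roughly 500 points, with `or 1` guarding against a step of zero for short frames.

```python
import numpy as np
from sklearn.metrics import auc, roc_curve

rng = np.random.default_rng(seed=0)
# synthetic DFT distances to the convex hull (eV/atom) and noisy "model" predictions
e_above_hull_true = rng.normal(0.1, 0.2, size=1000)
e_above_hull_pred = e_above_hull_true + rng.normal(0, 0.1, size=1000)

y_true = (e_above_hull_true <= 0).astype(int)  # 1 = stable, 0 = unstable
# pos_label=0: higher predicted energy above hull => more likely unstable
fpr, tpr, thresholds = roc_curve(y_true, e_above_hull_pred, pos_label=0)
print(f"AUC = {auc(fpr, tpr):.2f}")  # well above 0.5 since pred tracks truth
```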
68 changes: 34 additions & 34 deletions scripts/scatter_e_above_hull_models.py
@@ -127,7 +127,7 @@
     traces = [t for t in fig.data if t.xaxis == f"x{idx if idx > 1 else ''}"]
     assert len(traces) == 4, f"Expected 4 traces, got {len(traces)=}"

-    model = anno.text.split("=")[1]
+    model = anno.text.split("=", 1)[1]
     assert model in df_wbm, f"Unexpected {model=} not in {list(df_wbm)=}"
     # add MAE and R2 to subplot titles
     MAE, R2 = df_metrics[model][["MAE", "R2"]]
@@ -139,39 +139,39 @@
     fig.layout[f"xaxis{idx}"].title.text = ""
     fig.layout[f"yaxis{idx}"].title.text = ""

-    # add transparent rectangle with TN, TP, FN, FP labels in each quadrant
-    for sign_x, sign_y, color, label in zip(
-        [-1, -1, 1, 1], [-1, 1, -1, 1], clf_colors, ("TP", "FN", "FP", "TN")
-    ):
-        # instead of coloring points in each quadrant, we can add a transparent
-        # background to each quadrant (looks worse maybe than coloring points)
-        # fig.add_shape(
-        #     type="rect",
-        #     x0=0,
-        #     y0=0,
-        #     x1=sign_x * 100,
-        #     y1=sign_y * 100,
-        #     fillcolor=color,
-        #     opacity=0.5,
-        #     layer="below",
-        #     xref=f"x{idx}",
-        #     yref=f"y{idx}",
-        # )
-        fig.add_annotation(
-            xref=f"x{idx}",
-            yref=f"y{idx}",
-            x=sign_x * xy_max,
-            y=sign_y * xy_max,
-            xshift=-20 * sign_x,
-            yshift=-20 * sign_y,
-            text=label,
-            showarrow=False,
-            font=dict(size=16, color=color),
-        )
-
-    # add dashed quadrant separators
-    fig.add_vline(x=0, line=dict(width=0.5, dash="dash"))
-    fig.add_hline(y=0, line=dict(width=0.5, dash="dash"))
+# add transparent rectangle with TN, TP, FN, FP labels in each quadrant
+for sign_x, sign_y, color, label in zip(
+    [-1, -1, 1, 1], [-1, 1, -1, 1], clf_colors, ("TP", "FN", "FP", "TN")
+):
+    # instead of coloring points in each quadrant, we can add a transparent
+    # background to each quadrant (looks worse maybe than coloring points)
+    # fig.add_shape(
+    #     type="rect",
+    #     x0=0,
+    #     y0=0,
+    #     x1=sign_x * 100,
+    #     y1=sign_y * 100,
+    #     fillcolor=color,
+    #     opacity=0.2,
+    #     layer="below",
+    #     row="all",
+    #     col="all",
+    # )
+    fig.add_annotation(
+        x=sign_x * xy_max,
+        y=sign_y * xy_max,
+        xshift=-20 * sign_x,
+        yshift=-20 * sign_y,
+        text=label,
+        showarrow=False,
+        font=dict(size=16, color=color),
+        row="all",
+        col="all",
+    )

+# add dashed quadrant separators
+fig.add_vline(x=0, line=dict(width=0.5, dash="dash"))
+fig.add_hline(y=0, line=dict(width=0.5, dash="dash"))

 fig.update_xaxes(nticks=5)
 fig.update_yaxes(nticks=5)
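The restructured block moves the quadrant labels out of the per-subplot loop: instead of targeting each facet via `xref=f"x{idx}"`/`yref=f"y{idx}"`, a single `add_annotation` call with `row="all", col="all"` stamps the label onto every facet. A minimal sketch of that broadcasting mechanism on a toy faceted figure (not the repo's):

```python
import plotly.express as px

df_iris = px.data.iris()
fig = px.scatter(df_iris, x="sepal_width", y="sepal_length", facet_col="species")

# one call annotates every facet of the subplot grid
fig.add_annotation(x=3, y=6, text="TP", showarrow=False, row="all", col="all")
# vertical separator, likewise broadcast to all facets
fig.add_vline(x=3, line=dict(width=0.5, dash="dash"), row="all", col="all")
fig.show()
```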
3 changes: 1 addition & 2 deletions site/src/routes/about-the-test-set/tmi/+page.svelte
@@ -40,13 +40,12 @@

 Stuff that didn't make the cut into the main page describing the WBM test set.

-<ColorScaleSelect bind:selected />
-
 <h2>WBM Element Counts for <code>{filter}</code></h2>

 Filter WBM element counts by composition arity (how many elements in the formula) or batch
 index (which iteration of elemental substitution the structure was generated in).

+<ColorScaleSelect bind:selected />
 <ul>
   <li>
     composition arity:
2 changes: 1 addition & 1 deletion site/src/routes/si/+page.md
@@ -16,7 +16,7 @@
 <RocModels />
 {/if}

-> @label:fig:roc-models Receiver operating characteristic (ROC) curve for each model. TPR/FPR is the true and false positive rates. Points are colored by their stability threshold. A material is classified as stable if E<sub>above hull</sub> is less than the stability threshold. Since all models predict E<sub>form</sub> (and M3GNet predicted energies are converted to formation energy before stability classification), they are insensitive to changes in the threshold.
+> @label:fig:roc-models Receiver operating characteristic (ROC) curve for each model. TPR/FPR is the true/false positive rate. FPR means the $x$-axis is the fraction of unstable structures classified as stable while TPR on the $y$-axis is the fraction of stable structures classified as stable. Points are colored by stability threshold $t$ which sweeps from $-0.4 \ \frac{\text{eV}}{\text{atom}} \leq t \leq 0.4 \ \frac{\text{eV}}{\text{atom}}$ above the hull. A material is classified as stable if the predicted E<sub>above hull</sub> is less than the stability threshold. Since all models predict E<sub>form</sub> (and M3GNet predicted energies are converted to formation energy before stability classification), they are insensitive to changes in the threshold $t$. M3GNet wins in area under curve (AUC) with 0.87, coming in 34% higher than the worst model Voronoi Random Forest. The diagonal 'No skill' line shows performance of a dummy model that randomly ranks materials stability.

 ## Model Run Times
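As a quick reference for the caption's verbal TPR/FPR definitions, in confusion-matrix terms with "stable" taken as the positive class:

$$\mathrm{TPR} = \frac{\mathrm{TP}}{\mathrm{TP} + \mathrm{FN}}, \qquad \mathrm{FPR} = \frac{\mathrm{FP}}{\mathrm{FP} + \mathrm{TN}}$$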
15 changes: 9 additions & 6 deletions tests/test_preds.py
@@ -1,5 +1,6 @@
 from matbench_discovery.data import PRED_FILENAMES
 from matbench_discovery.preds import (
+    df_each_err,
     df_each_pred,
     df_metrics,
     df_wbm,
@@ -23,9 +24,11 @@ def test_df_metrics() -> None:

 def test_df_each_pred() -> None:
     assert len(df_each_pred) == len(df_wbm)
-    assert (
-        {*df_each_pred} == {*df_metrics} < {*df_wbm}
-    ), "df_each_pred has wrong columns"
-    assert all(
-        df_each_pred.isna().sum() / len(df_each_pred) < 0.05
-    ), "too many NaNs in df_each_pred"
+    assert {*df_each_pred} == {*df_metrics}, "df_each_pred has wrong columns"
+    assert all(df_each_pred.isna().mean() < 0.05), "too many NaNs in df_each_pred"
+
+
+def test_df_each_err() -> None:
+    assert len(df_each_err) == len(df_wbm)
+    assert {*df_each_err} == {*df_metrics}, "df_each_err has wrong columns"
+    assert all(df_each_err.isna().mean() < 0.05), "too many NaNs in df_each_err"
