Skip to content

Commit

Permalink
add figs/cumulative-mae-rmse.svelte displayed in /si
Browse files Browse the repository at this point in the history
rename scripts/(cumulative_clf_metrics.py -> cumulative_metrics).py
rename figs/(cumulative-clf-metrics -> site/src/figs/cumulative-precision-recall).svelte
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 70b2b1d commit 8798786
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 50 deletions.
60 changes: 43 additions & 17 deletions matbench_discovery/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,10 +591,10 @@ def rolling_mae_vs_hull_dist(
return fig, df_rolling_err, df_err_std


def cumulative_precision_recall(
def cumulative_metrics(
e_above_hull_true: pd.Series,
df_preds: pd.DataFrame,
metrics: Sequence[str] = ("Cumulative Precision", "Cumulative Recall"),
metrics: Sequence[str] = ("Precision", "Recall"),
stability_threshold: float = 0, # set stability threshold as distance to convex
# hull in eV / atom, usually 0 or 0.1 eV
project_end_point: Literal["x", "y", "xy", ""] = "xy",
Expand Down Expand Up @@ -638,32 +638,57 @@ def cumulative_precision_recall(
"""
factory = lambda: pd.DataFrame(index=range(len(e_above_hull_true)))
dfs: dict[str, pd.DataFrame] = defaultdict(factory)
metrics_no_case = [*map(str.casefold, metrics)]

valid_metrics = {"precision", "recall", "f1", "mae", "rmse"}
if invalid_metrics := set(metrics_no_case) - valid_metrics:
raise ValueError(
f"{invalid_metrics=}, should be case-insensitive subset of {valid_metrics=}"
)

for model_name in df_preds:
each_pred = df_preds[model_name].sort_values()
# sort targets by model ranking
each_true = e_above_hull_true.loc[each_pred.index]

true_pos_cum, false_neg_cum, false_pos_cum, _true_neg_cum = map(
true_pos_cum, false_neg_cum, false_pos_cum, true_neg_cum = map(
np.cumsum, classify_stable(each_true, each_pred, stability_threshold)
)

# precision aka positive predictive value (PPV)
precision_cum = true_pos_cum / (true_pos_cum + false_pos_cum)
n_total_pos = true_pos_cum[-1] + false_neg_cum[-1]
precision_cum = true_pos_cum / (true_pos_cum + false_pos_cum)
recall_cum = true_pos_cum / n_total_pos # aka true_pos_rate aka sensitivity
# cumulative F1 score
f1_cum = 2 * (precision_cum * recall_cum) / (precision_cum + recall_cum)

end = int(np.argmax(recall_cum))
xs = np.arange(end)
prec_interp = scipy.interpolate.interp1d(xs, precision_cum[:end], kind="cubic")
recall_interp = scipy.interpolate.interp1d(xs, recall_cum[:end], kind="cubic")
f1_interp = scipy.interpolate.interp1d(xs, f1_cum[:end], kind="cubic")

dfs["Cumulative Precision"][model_name] = pd.Series(prec_interp(xs))
dfs["Cumulative Recall"][model_name] = pd.Series(recall_interp(xs))
dfs["Cumulative F1"][model_name] = pd.Series(f1_interp(xs))
if "precision" in metrics_no_case:
prec_interp = scipy.interpolate.interp1d(
xs, precision_cum[:end], kind="cubic"
)
dfs["Precision"][model_name] = pd.Series(prec_interp(xs))
if "recall" in metrics_no_case:
recall_interp = scipy.interpolate.interp1d(
xs, recall_cum[:end], kind="cubic"
)
dfs["Recall"][model_name] = pd.Series(recall_interp(xs))
if "f1" in metrics_no_case:
f1_cum = 2 * (precision_cum * recall_cum) / (precision_cum + recall_cum)
f1_interp = scipy.interpolate.interp1d(xs, f1_cum[:end], kind="cubic")
dfs["F1"][model_name] = pd.Series(f1_interp(xs))

if "mae" in metrics_no_case:
cum_errors = (each_true - each_pred).abs().cumsum()
cum_counts = np.arange(1, len(each_true) + 1)
mae_cum = cum_errors / cum_counts
mae_interp = scipy.interpolate.interp1d(xs, mae_cum[:end], kind="cubic")
dfs["MAE"][model_name] = pd.Series(mae_interp(xs))

if "rmse" in metrics_no_case:
rmse_cum = (((each_true - each_pred) ** 2).cumsum() / cum_counts) ** 0.5
rmse_interp = scipy.interpolate.interp1d(xs, rmse_cum[:end], kind="cubic")
dfs["RMSE"][model_name] = pd.Series(rmse_interp(xs))

for key in dfs:
# drop all-NaN rows so plotly plot x-axis only extends to largest number of
Expand Down Expand Up @@ -736,19 +761,20 @@ def cumulative_precision_recall(
)

elif backend == "plotly":
fig = df_cum.query(f"metric in {metrics}").plot(
n_cols = kwargs.pop("facet_col_wrap", 2)
kwargs.setdefault("facet_col_spacing", 0.03)
fig = df_cum.plot(
backend=backend,
facet_col="metric",
facet_col_wrap=2,
facet_col_spacing=0.03,
facet_col_wrap=n_cols,
**kwargs,
)

line_kwds = dict(dash="dash", width=0.5)
for idx, anno in enumerate(fig.layout.annotations):
anno.text = anno.text.split("=")[1]
anno.font.size = 16
grid_pos = dict(row=idx // 2 + 1, col=idx % 2 + 1)
grid_pos = dict(row=idx // n_cols + 1, col=idx % n_cols + 1)
fig.update_traces(
hovertemplate=f"Index = %{{x:d}}<br>{anno.text} = %{{y:.2f}}",
**grid_pos,
Expand All @@ -774,7 +800,7 @@ def cumulative_precision_recall(
fig.add_annotation(
x=n_stable,
y=0.95,
text="Stable<br>Materials",
text="Stable Materials",
showarrow=False,
xanchor="left",
align="left",
Expand Down
44 changes: 29 additions & 15 deletions scripts/cumulative_clf_metrics.py → scripts/cumulative_metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Plot cumulative precision and/or recall and/or F1 curves for all models into facet
plot with one subplot per metric. Cumulative here means going through the list of WBM
materials ranked by the model's stability prediction starting from the most stable
and updating the precision, recall and F1 score after each new material. This plot
"""Plot cumulative metrics like precision, recall, F1, MAE, RMSE as lines for all models
into face plot with one subplot per metric. Cumulative here means descending the list of
test set materials ranked by model-predicted stability starting from the most stable
and updating the metric (Recall, MAE, etc.) after each new material. This plot
simulates an actual materials screening process and allows practitioners to choose
a cutoff point for the number of DFT calculations they have budget and see which model
will provide the best hit rate for the given budget.
Expand All @@ -13,30 +13,43 @@
from pymatviz.utils import save_fig

from matbench_discovery import FIGS, PDF_FIGS
from matbench_discovery.plots import cumulative_precision_recall
from matbench_discovery.preds import df_each_pred, df_metrics, df_preds, each_true_col
from matbench_discovery.plots import cumulative_metrics
from matbench_discovery.preds import (
df_each_pred,
df_metrics,
df_preds,
each_true_col,
models,
)

__author__ = "Janosh Riebesell, Rhys Goodall"
__date__ = "2022-12-04"


# %%
fig, df_metric = cumulative_precision_recall(
# metrics = ("Precision", "Recall")
metrics = ("MAE", "RMSE")
fig, df_metric = cumulative_metrics(
e_above_hull_true=df_preds[each_true_col],
df_preds=df_each_pred,
df_preds=df_each_pred[models],
project_end_point="xy",
backend=(backend := "plotly"),
range_y=(0, 1)
# template="plotly_white",
range_y=(0, 0.4),
metrics=metrics,
# facet_col_wrap=2,
# increase facet col gap
facet_col_spacing=0.07,
)

x_label = "Number of screened WBM materials"
if backend == "matplotlib":
# fig.suptitle(title)
fig.text(0.5, -0.08, x_label, ha="center", fontdict={"size": 16})
if backend == "plotly":
fig.layout.legend.update(x=0, y=0, bgcolor="rgba(0,0,0,0)")
fig.layout.margin.update(l=0, r=5, t=30, b=50)
fig.layout.legend.update(
x=1, y=1, bgcolor="rgba(0,0,0,0)", xanchor="right", yanchor="top"
)
fig.layout.margin.update(l=0, r=0, t=30, b=50)
fig.add_annotation(
x=0.5,
y=-0.15,
Expand All @@ -48,8 +61,9 @@
)
fig.update_traces(line=dict(width=3))
for trace in fig.data:
# show only the N best models by default
if trace.name in df_metrics.T.sort_values("F1").index[:-6]:
trace.visible = "legendonly" # show only top models by default
trace.visible = "legendonly"
last_idx = pd.Series(trace.y).last_valid_index()
last_x = trace.x[last_idx]
last_y = trace.y[last_idx]
Expand Down Expand Up @@ -89,6 +103,6 @@


# %%
img_name = "cumulative-clf-metrics"
img_name = f"cumulative-{'-'.join(metrics).lower()}"
save_fig(fig, f"{FIGS}/{img_name}.svelte")
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=720, height=370)
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=900, height=330)
1 change: 0 additions & 1 deletion scripts/rolling_mae_vs_hull_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
e_above_hull_errors={model: df_preds[e_form_col] - df_preds[model]},
# label=model,
backend=(backend := "plotly"),
# template="plotly_white",
)

MAE, DAF, F1 = df_metrics[model][["MAE", "DAF", "F1"]]
Expand Down
1 change: 1 addition & 0 deletions site/src/figs/cumulative-mae-rmse.svelte

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions site/src/figs/cumulative-precision-recall.svelte

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions site/src/routes/preprint/+page.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import { repository as repo } from '$site/package.json'
import EachScatterModels from '$figs/each-scatter-models-4x2.svelte'
import MetricsTable from '$figs/metrics-table.svelte'
import CumulativeClfMetrics from '$figs/cumulative-clf-metrics.svelte'
import CumulativePrecisionRecall from '$figs/cumulative-precision-recall.svelte'
import RollingMaeVsHullDistModels from '$figs/rolling-mae-vs-hull-dist-models.svelte'
import ElementErrorsPtableHeatmap from '$models/element-errors-ptable-heatmap.svelte'
import HistClfTrueHullDistModels from '$figs/hist-clf-true-hull-dist-models-4x2.svelte'
Expand Down Expand Up @@ -166,23 +166,23 @@ Our initial benchmark release includes 8 models.
@Fig:metrics-table shows performance metrics for all models included in the initial release of Matbench Discovery.
CHGNet takes the top spot on all metrics except true positive rate (TPR) and emerges as current SOTA for ML-guided materials discovery. The discovery acceleration factor (DAF) measures how many more stable structures a model found compared to the dummy discovery rate of 43k / 257k $\approx$ 16.7\% achieved by randomly selecting test set crystals. Consequently, the maximum possible DAF is ~6. This highlights the fact that our benchmark is made more challenging by deploying models on an already enriched space with a much higher fraction of stable structures than uncharted materials space at large. As the convex hull becomes more thoroughly sampled by future discovery, the fraction of unknown stable structures decreases, naturally leading to less enriched future test sets which will allow for higher maximum DAFs.

Note that MEGNet outperforms M3GNet on DAF (2.70 vs 2.66) even though M3GNet is superior to MEGNet in all other metrics. The reason is the one outlined in the previous paragraph as becomes clear from @fig:cumulative-clf-metrics. MEGNet's line ends at 55.6 k materials which is closest to the true number of 43 k stable materials in our test set. All other models overpredict the sum total of stable materials by anywhere from 40% (~59 k for CGCNN) to 104% (85 k for Wrenformer), resulting in large numbers of false positive predictions which lower their DAFs.
Note that MEGNet outperforms M3GNet on DAF (2.70 vs 2.66) even though M3GNet is superior to MEGNet in all other metrics. The reason is the one outlined in the previous paragraph as becomes clear from @fig:cumulative-precision-recall. MEGNet's line ends at 55.6 k materials which is closest to the true number of 43 k stable materials in our test set. All other models overpredict the sum total of stable materials by anywhere from 40% (~59 k for CGCNN) to 104% (85 k for Wrenformer), resulting in large numbers of false positive predictions which lower their DAFs.

As noted, this is only a problem in practice for exhaustive discovery campaigns that validate _all_ stable predictions from a model. More frequently, model predictions will be ranked most-to-least stable and validation stops after some pre-determined compute budget is spent, say, 10k DFT relaxations. In that case, most of the false positive predictions near the less stable end of the candidate list are ignored and don't harm the campaign's overall discovery count.

We find a large performance gap between models that make one-shot predictions directly from unrelaxed inputs such as MEGNet, Wrenformer, CGCNN, CGCNN+P, Voronoi RF versus UIPs that predict forces to emulate DFT relaxation. While the F1 scores and DAFs of non-UIPs are seemingly unaffected, their $R^2$ coefficients are significantly worse. Except for CGCNN+P, all fail to achieve positive $R^2$. This means their predictions explain the observed variation in the data less than a horizontal line through the test set mean. In other words, these models are not predictive in a global sense (across the full dataset range). However, even models with negative $R^2$ can be locally good in the positive and negative tails of the test set hull distance distribution. They suffer most in the mode near the stability threshold of 0 eV/atom above the hull. This reveals an important shortcoming of $R^2$ as a metric for classification tasks like ours.

The reason CGCNN+P achieves better regression metrics than CGCNN but is still worse as a classifier becomes apparent from [the SI histograms](/si#fig:hist-clf-pred-hull-dist-models) by noting that the CGCNN+P histogram is more sharply peaked at the 0 hull distance stability threshold. This causes even small errors in the predicted convex hull distance to be large enough to invert a classification. Again, this is evidence to choose carefully which metrics to optimize. Regression metrics are far more prevalent when evaluating energy predictions. In our benchmark, energies are just means to an end to classify compound stability. Regression accuracy is of little use on its own unless it helps classification. The field needs to be aware that this is not a given.

### Cumulative Classification Metrics
### Cumulative Precision + Recall

{#if mounted}
<CumulativeClfMetrics style="margin: 0 -2em 0 -4em;" />
<CumulativePrecisionRecall style="margin: 0 -2em 0 -4em;" />
{/if}

> @label:fig:cumulative-clf-metrics Cumulative precision and recall over the course of a simulated discovery campaign. This figure highlights how different models perform better or worse depending on the length of the discovery campaign. Length here is an integer measuring how many DFT relaxations you have compute budget for. We only show the 6 best performing models for visual clarity.
> @label:fig:cumulative-precision-recall Cumulative precision and recall over the course of a simulated discovery campaign. This figure highlights how different models perform better or worse depending on the length of the discovery campaign. Length here is an integer measuring how many DFT relaxations you have compute budget for. We only show the 6 best performing models for visual clarity.
@Fig:cumulative-clf-metrics simulates ranking materials from most to least stable according to model-predicted energies. For each model, we go down that list material by material, calculating at each step the precision and recall of correctly identified stable materials. This simulates exactly how these models might be used in a prospective materials discovery campaign and reveal how a model's performance changes as a function of the discovery campaign length, i.e. the amount of resources available to validate model predictions.
@Fig:cumulative-precision-recall simulates ranking materials from most to least stable according to model-predicted energies. For each model, we go down that list material by material, calculating at each step the precision and recall of correctly identified stable materials. This simulates exactly how these models might be used in a prospective materials discovery campaign and reveal how a model's performance changes as a function of the discovery campaign length, i.e. the amount of resources available to validate model predictions.

A line terminates when a model believes there are no more materials in the WBM test set below the MP convex hull. The dashed vertical line shows the actual number of stable structures in our test set. All models are biased towards stability to some degree as they all overestimate this number, most of all Wrenformer by over 112%, least of all MEGNet by 30%. The diagonal Optimal Recall line would be achieved if a model never made a false negative prediction and stops predicting stable crystals exactly when the true number of stable materials is reached. Zooming in on the top-left corner of the precision plot, we observe that CHGNet is the only model without a sudden drop in precision right at the start of the discovery campaign. It keeps a strong lead over the runner-up M3GNet until reaching ~3k screened materials. From there, the CHGNet and M3GNet lines slowly converge until they almost tie at a precision of 0.52 after ~56k screened materials. At that point, CHGNet's list of stable predictions is exhausted while M3GNet continues to drop to 0.45 at 76 k, attributable to many false positives towards the end of the stable list.

Expand Down
Loading

0 comments on commit 8798786

Please sign in to comment.