Skip to content

Commit

Permalink
add classify_stable() in matbench_discovery/energy.py
Browse files Browse the repository at this point in the history
used by plots cumulative_clf_metric() and hist_classified_stable_vs_hull_dist()
add test_classify_stable()
  • Loading branch information
janosh committed Jun 20, 2023
1 parent c5d3496 commit d593ae2
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 102 deletions.
36 changes: 36 additions & 0 deletions matbench_discovery/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,39 @@ def get_e_form_per_atom(
form_energy = energy - sum(comp[el] * refs[str(el)].energy_per_atom for el in comp)

return form_energy / comp.num_atoms


def classify_stable(
e_above_hull_true: pd.Series,
e_above_hull_pred: pd.Series,
stability_threshold: float = 0,
) -> tuple[pd.Series, pd.Series, pd.Series, pd.Series]:
"""Classify model stability predictions as true/false positive/negatives depending
on if material is actually stable or unstable. All energies are assumed to be in
eV/atom (but shouldn't really matter as long as they're consistent).
Args:
e_above_hull_true (pd.Series): Ground truth energy above convex hull values.
e_above_hull_pred (pd.Series): Model predicted energy above convex hull values.
stability_threshold (float, optional): Maximum energy above convex hull for a
material to still be considered stable. Usually 0, 0.05 or 0.1. Defaults to
0. 0 means a material has to be directly on the hull to be called stable.
Negative values mean a material has to pull the known hull down by that
amount to count as stable. Few materials lie below the known hull, so only
negative values close to 0 make sense.
Returns:
tuple[pd.Series, pd.Series, pd.Series, pd.Series]: Indices for true positives,
false negatives, false positives and true negatives (in this order).
"""
actual_pos = e_above_hull_true <= stability_threshold
actual_neg = e_above_hull_true > stability_threshold
model_pos = e_above_hull_pred <= stability_threshold
model_neg = e_above_hull_pred > stability_threshold

true_pos = actual_pos & model_pos
false_neg = actual_pos & model_neg
false_pos = actual_neg & model_pos
true_neg = actual_neg & model_neg

return true_pos, false_neg, false_pos, true_neg
110 changes: 48 additions & 62 deletions matbench_discovery/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import wandb
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

from matbench_discovery.energy import classify_stable

__author__ = "Janosh Riebesell"
__date__ = "2022-08-05"

Expand Down Expand Up @@ -69,8 +71,8 @@


def hist_classified_stable_vs_hull_dist(
e_above_hull_pred: pd.Series,
e_above_hull_true: pd.Series,
e_above_hull_pred: pd.Series,
ax: plt.Axes = None,
which_energy: WhichEnergy = "true",
stability_threshold: float = 0,
Expand All @@ -90,14 +92,14 @@ def hist_classified_stable_vs_hull_dist(
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
Args:
e_above_hull_pred (pd.Series): energy difference to convex hull predicted by
model, i.e. difference between the model's predicted and true formation
energy.
e_above_hull_true (pd.Series): energy diff to convex hull according to DFT
ground truth.
e_above_hull_true (pd.Series): Distance to convex hull according to DFT
ground truth (in eV / atom).
e_above_hull_pred (pd.Series): Distance to convex hull predicted by model
(in eV / atom). Same as true energy to convex hull plus predicted minus true
formation energy.
ax (plt.Axes, optional): matplotlib axes to plot on.
which_energy (WhichEnergy, optional): Whether to use the true formation energy
or the model's predicted formation energy for the histogram.
which_energy (WhichEnergy, optional): Whether to use the true (DFT) hull
distance or the model's predicted hull distance for the histogram.
stability_threshold (float, optional): set stability threshold as distance to
convex hull in eV/atom, usually 0 or 0.1 eV.
show_threshold (bool, optional): Whether to plot stability threshold as dashed
Expand All @@ -114,36 +116,28 @@ def hist_classified_stable_vs_hull_dist(
"""
ax = ax or plt.gca()

test = e_above_hull_pred + e_above_hull_true
# --- histogram of DFT-computed distance to convex hull
if which_energy == "true":
actual_pos = e_above_hull_true <= stability_threshold
actual_neg = e_above_hull_true > stability_threshold
model_pos = test <= stability_threshold
model_neg = test > stability_threshold

n_true_pos = len(e_above_hull_true[actual_pos & model_pos])
n_false_neg = len(e_above_hull_true[actual_pos & model_neg])

n_total_pos = n_true_pos + n_false_neg
null = n_total_pos / len(e_above_hull_true)

true_pos = e_above_hull_true[actual_pos & model_pos]
false_neg = e_above_hull_true[actual_pos & model_neg]
false_pos = e_above_hull_true[actual_neg & model_pos]
true_neg = e_above_hull_true[actual_neg & model_neg]
xlabel = r"$E_\mathrm{above\ hull}$ (eV / atom)"

# --- histogram of model-predicted distance to convex hull
if which_energy == "pred":
true_pos = e_above_hull_pred[actual_pos & model_pos]
false_neg = e_above_hull_pred[actual_pos & model_neg]
false_pos = e_above_hull_pred[actual_neg & model_pos]
true_neg = e_above_hull_pred[actual_neg & model_neg]
xlabel = r"$\Delta E_{Hull-Pred}$ (eV / atom)"
true_pos, false_neg, false_pos, true_neg = classify_stable(
e_above_hull_true, e_above_hull_pred, stability_threshold
)
n_true_pos = sum(true_pos)
n_false_neg = sum(false_neg)

n_total_pos = n_true_pos + n_false_neg
null = n_total_pos / len(e_above_hull_true)

# toggle between histogram of DFT-computed/model-predicted distance to convex hull
e_above_hull = e_above_hull_true if which_energy == "true" else e_above_hull_pred
eah_true_pos = e_above_hull[true_pos]
eah_false_neg = e_above_hull[false_neg]
eah_false_pos = e_above_hull[false_pos]
eah_true_neg = e_above_hull[true_neg]
xlabel = dict(
true="$E_\\mathrm{above\\ hull}$ (eV / atom)",
pred="$E_\\mathrm{above\\ hull\\ pred}$ (eV / atom)",
)[which_energy]

ax.hist(
[true_pos, false_neg, false_pos, true_neg],
[eah_true_pos, eah_false_neg, eah_false_pos, eah_true_neg],
bins=200,
range=x_lim,
alpha=0.5,
Expand All @@ -158,7 +152,7 @@ def hist_classified_stable_vs_hull_dist(
)

n_true_pos, n_false_pos, n_true_neg, n_false_neg = map(
len, (true_pos, false_pos, true_neg, false_neg)
len, (eah_true_pos, eah_false_pos, eah_true_neg, eah_false_neg)
)
# null = (tp + fn) / (tp + tn + fp + fn)
precision = n_true_pos / (n_true_pos + n_false_pos)
Expand All @@ -181,8 +175,8 @@ def hist_classified_stable_vs_hull_dist(
# compute accuracy within 20 meV/atom intervals
bins = np.arange(x_lim[0], x_lim[1], rolling_accuracy)
bin_counts = np.histogram(e_above_hull_true, bins)[0]
bin_true_pos = np.histogram(true_pos, bins)[0]
bin_true_neg = np.histogram(true_neg, bins)[0]
bin_true_pos = np.histogram(eah_true_pos, bins)[0]
bin_true_neg = np.histogram(eah_true_neg, bins)[0]

# compute accuracy
bin_accuracies = (bin_true_pos + bin_true_neg) / bin_counts
Expand Down Expand Up @@ -327,8 +321,8 @@ def rolling_mae_vs_hull_dist(


def cumulative_clf_metric(
e_above_hull_error: pd.Series,
e_above_hull_true: pd.Series,
e_above_hull_pred: pd.Series,
metric: Literal["precision", "recall"],
stability_threshold: float = 0, # set stability threshold as distance to convex
# hull in eV / atom, usually 0 or 0.1 eV
Expand All @@ -344,11 +338,11 @@ def cumulative_clf_metric(
predicted stable are included.
Args:
df (pd.DataFrame): Model predictions and target energy values.
e_above_hull_error (str, optional): Column name with residuals of model
predictions, i.e. residual = pred - target. Defaults to "residual".
e_above_hull_true (str, optional): Column name with convex hull distance values.
Defaults to "e_above_hull".
e_above_hull_true (pd.Series): Distance to convex hull according to DFT
ground truth (in eV / atom).
e_above_hull_pred (pd.Series): Distance to convex hull predicted by model
(in eV / atom). Same as true energy to convex hull plus predicted minus true
formation energy.
metric ('precision' | 'recall', optional): Metric to plot.
stability_threshold (float, optional): Max distance from convex hull before
material is considered unstable. Defaults to 0.
Expand All @@ -365,25 +359,19 @@ def cumulative_clf_metric(
"""
ax = ax or plt.gca()

e_above_hull_error = e_above_hull_error.sort_values()
e_above_hull_true = e_above_hull_true.loc[e_above_hull_error.index]
e_above_hull_pred = e_above_hull_pred.sort_values()
e_above_hull_true = e_above_hull_true.loc[e_above_hull_pred.index]

true_pos_mask = (e_above_hull_true <= stability_threshold) & (
e_above_hull_error <= stability_threshold
)
false_neg_mask = (e_above_hull_true <= stability_threshold) & (
e_above_hull_error > stability_threshold
)
false_pos_mask = (e_above_hull_true > stability_threshold) & (
e_above_hull_error <= stability_threshold
true_pos, false_neg, false_pos, _true_neg = classify_stable(
e_above_hull_true, e_above_hull_pred, stability_threshold
)

true_pos_cumsum = true_pos_mask.cumsum()
true_pos_cumsum = true_pos.cumsum()

# precision aka positive predictive value (PPV)
precision = true_pos_cumsum / (true_pos_cumsum + false_pos_mask.cumsum()) * 100
n_true_pos = sum(true_pos_mask)
n_false_neg = sum(false_neg_mask)
precision = true_pos_cumsum / (true_pos_cumsum + false_pos.cumsum()) * 100
n_true_pos = sum(true_pos)
n_false_neg = sum(false_neg)
n_total_pos = n_true_pos + n_false_neg
true_pos_rate = true_pos_cumsum / n_total_pos * 100

Expand Down Expand Up @@ -443,9 +431,7 @@ def cumulative_clf_metric(
return ax


def wandb_log_scatter(
table: wandb.Table, fields: dict[str, str], **kwargs: Any
) -> None:
def wandb_scatter(table: wandb.Table, fields: dict[str, str], **kwargs: Any) -> None:
"""Log a parity scatter plot using custom vega spec to WandB.
Args:
Expand Down
4 changes: 2 additions & 2 deletions models/cgcnn/test_cgcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from matbench_discovery import DEBUG, ROOT, today
from matbench_discovery.load_preds import df_wbm
from matbench_discovery.plots import wandb_log_scatter
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit

__author__ = "Janosh Riebesell"
Expand Down Expand Up @@ -124,4 +124,4 @@

title = rf"CGCNN {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"

wandb_log_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)
wandb_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)
4 changes: 2 additions & 2 deletions models/megnet/test_megnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from matbench_discovery import DEBUG, ROOT, timestamp, today
from matbench_discovery.load_preds import df_wbm
from matbench_discovery.plots import wandb_log_scatter
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit

"""
Expand Down Expand Up @@ -115,4 +115,4 @@
title = f"{model_name} {task_type} {MAE=:.4} {R2=:.4}"
print(title)

wandb_log_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)
wandb_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)
4 changes: 2 additions & 2 deletions models/voronoi/train_test_voronoi_rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from matbench_discovery import DEBUG, ROOT, today
from matbench_discovery.load_preds import df_wbm, glob_to_df
from matbench_discovery.plots import wandb_log_scatter
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit
from models.voronoi import featurizer

Expand Down Expand Up @@ -127,4 +127,4 @@
title = f"{model_name} {task_type} {MAE=:.3} {R2=:.3}"
print(title)

wandb_log_scatter(table, fields=dict(x=test_target_col, y=pred_col), title=title)
wandb_scatter(table, fields=dict(x=test_target_col, y=pred_col), title=title)
4 changes: 2 additions & 2 deletions models/wrenformer/test_wrenformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from aviary.wrenformer.model import Wrenformer

from matbench_discovery import DEBUG, ROOT, today
from matbench_discovery.plots import wandb_log_scatter
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit

__author__ = "Janosh Riebesell"
Expand Down Expand Up @@ -110,4 +110,4 @@

title = rf"Wrenformer {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"

wandb_log_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)
wandb_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)
10 changes: 8 additions & 2 deletions scripts/hist_classified_stable_vs_hull_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

# %%
target_col = "e_form_per_atom_mp2020_corrected"
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
which_energy: WhichEnergy = "true"
# std_factor=0,+/-1,+/-2,... changes the criterion for material stability to
# energy+std_factor*std. energy+std means predicted energy plus the model's uncertainty
Expand All @@ -40,10 +41,15 @@
var_epistemic = df_wbm.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
std_total = (var_epistemic + var_aleatoric) ** 0.5
std_total = df_wbm[f"{model_name}_std"]
e_above_hull_pred = (
df_wbm[e_above_hull_col]
+ (df_wbm[model_name] + std_factor * std_total)
- df_wbm[target_col]
)

ax, metrics = hist_classified_stable_vs_hull_dist(
e_above_hull_pred=df_wbm[model_name] - std_factor * std_total - df_wbm[target_col],
e_above_hull_true=df_wbm.e_above_hull_mp2020_corrected_ppd_mp,
e_above_hull_true=df_wbm[e_above_hull_col],
e_above_hull_pred=e_above_hull_pred,
which_energy=which_energy,
# stability_threshold=-0.05,
rolling_accuracy=0,
Expand Down
10 changes: 6 additions & 4 deletions scripts/hist_classified_stable_vs_hull_dist_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
assert 1e4 < len(batch_df) < 1e5, print(f"{len(batch_df) = :,}")

ax, metrics = hist_classified_stable_vs_hull_dist(
e_above_hull_pred=batch_df[model_name] - batch_df[target_col],
e_above_hull_true=batch_df[e_above_hull_col],
e_above_hull_pred=batch_df[e_above_hull_col]
+ (batch_df[model_name] - batch_df[target_col]),
which_energy=which_energy,
ax=ax,
)
Expand All @@ -53,8 +54,9 @@


ax, metrics = hist_classified_stable_vs_hull_dist(
e_above_hull_pred=df_wbm[model_name] - df_wbm[target_col],
e_above_hull_true=df_wbm[e_above_hull_col],
e_above_hull_pred=df_wbm[e_above_hull_col]
+ (df_wbm[model_name] - df_wbm[target_col]),
which_energy=which_energy,
ax=axs.flat[-1],
)
Expand All @@ -69,5 +71,5 @@


# %%
img_name = f"{today}-{model_name}-wbm-hull-dist-hist-batches"
ax.figure.savefig(f"{ROOT}/figures/{img_name}.pdf")
img_path = f"{ROOT}/figures/{today}-{model_name}-wbm-hull-dist-hist-batches.pdf"
# ax.figure.savefig(img_path)
29 changes: 8 additions & 21 deletions scripts/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,39 +26,26 @@

for model_name, color in zip(models, colors):

e_above_hull_pred = df_wbm[model_name] - df_wbm[target_col]

F1 = f1_score(df_wbm[e_above_hull_col] < 0, e_above_hull_pred < 0)

e_above_hull_error = e_above_hull_pred + df_wbm[e_above_hull_col]
cumulative_clf_metric(
e_above_hull_error,
df_wbm[e_above_hull_col],
color=color,
label=f"{model_name}\n{F1=:.3}",
project_end_point="xy",
ax=ax_prec,
metric="precision",
e_above_hull_pred = df_wbm[e_above_hull_col] + (
df_wbm[model_name] - df_wbm[target_col]
)

cumulative_clf_metric(
e_above_hull_error,
df_wbm[e_above_hull_col],
F1 = f1_score(df_wbm[e_above_hull_col] < 0, e_above_hull_pred < 0)
in_common = dict(
e_above_hull_true=df_wbm[e_above_hull_col],
e_above_hull_pred=e_above_hull_pred,
color=color,
label=f"{model_name}\n{F1=:.3}",
project_end_point="xy",
ax=ax_recall,
metric="recall",
)
cumulative_clf_metric(**in_common, ax=ax_prec, metric="precision")

cumulative_clf_metric(**in_common, ax=ax_recall, metric="recall")

for ax in (ax_prec, ax_recall):
ax.set(xlim=(0, None))


# x-ticks every 10k materials
# ax.set(xticks=range(0, int(ax.get_xlim()[1]), 10_000))

fig.suptitle(f"{today} {model_name}")
xlabel_cumulative = "Materials predicted stable sorted by hull distance"
fig.text(0.5, -0.08, xlabel_cumulative, ha="center")
Expand Down
Loading

0 comments on commit d593ae2

Please sign in to comment.