Skip to content

Commit

Permalink
move /paper/preprint to /paper
Browse files Browse the repository at this point in the history
add CHGNet to /si rolling MAE batches model comparison
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 0fad3bd commit 8a0c3bb
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 101 deletions.
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.258
rev: v0.0.260
hooks:
- id: ruff
args: [--fix]

- repo: https://github.com/psf/black
rev: 23.1.0
rev: 23.3.0
hooks:
- id: black

Expand All @@ -34,13 +34,13 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.0.1
rev: v1.1.1
hooks:
- id: mypy
additional_dependencies: [types-pyyaml, types-requests]

- repo: https://github.com/codespell-project/codespell
rev: v2.2.2
rev: v2.2.4
hooks:
- id: codespell
stages: [commit, commit-msg]
Expand All @@ -49,7 +49,7 @@ repos:
args: [--ignore-words-list, "nd,te,fpr"]

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.0-alpha.4
rev: v3.0.0-alpha.6
hooks:
- id: prettier
args: [--write] # edit files in-place
Expand All @@ -60,7 +60,7 @@ repos:
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yml|yaml|json))$

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v8.34.0
rev: v8.37.0
hooks:
- id: eslint
types: [file]
Expand Down
9 changes: 7 additions & 2 deletions matbench_discovery/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def rolling_mae_vs_hull_dist(
with_sem: bool = True,
show_dft_acc: bool = False,
show_dummy_mae: bool = False,
pbar: bool = True,
**kwargs: Any,
) -> plt.Axes | go.Figure:
"""Rolling mean absolute error as the energy to the convex hull is varied. A scale
Expand Down Expand Up @@ -380,6 +381,8 @@ def rolling_mae_vs_hull_dist(
meV/atom. Defaults to False.
show_dummy_mae (bool, optional): If True, plot a line at the dummy MAE of always
predicting the target mean.
pbar (bool, optional): If True, show a progress bar during rolling MAE
calculation. Defaults to True.
**kwargs: Additional keyword arguments to pass to df.plot().
Returns:
Expand All @@ -396,8 +399,10 @@ def rolling_mae_vs_hull_dist(
df_rolling_err = pd.DataFrame(columns=models, index=bins)
df_err_std = df_rolling_err.copy()

for model in (pbar := tqdm(models, desc="Calculating rolling MAE")):
pbar.set_postfix_str(model)
for model in (
prog_bar := tqdm(models, desc="Calculating rolling MAE", disable=not pbar)
):
prog_bar.set_postfix_str(model)
for idx, bin_center in enumerate(bins):
low = bin_center - window
high = bin_center + window
Expand Down
2 changes: 1 addition & 1 deletion scripts/hist_classified_stable_vs_hull_dist_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
df_preds[each_pred_col] = (
df_preds[each_true_col] + df_preds[model_name] - df_preds[e_form_col]
)
df_preds[(batch_col := "batch_idx")] = df_preds.index.str.split("-").str[-2].astype(int)
df_preds[(batch_col := "batch_idx")] = df_preds.index.str.split("-").str[1].astype(int)


# %% matplotlib
Expand Down
11 changes: 4 additions & 7 deletions scripts/hist_classified_stable_vs_hull_dist_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@
from pymatviz.utils import save_fig

from matbench_discovery import FIGS, ROOT, today
from matbench_discovery.plots import (
hist_classified_stable_vs_hull_dist,
plt,
)
from matbench_discovery.plots import hist_classified_stable_vs_hull_dist, plt
from matbench_discovery.preds import df_metrics, df_preds, e_form_col, each_true_col

__author__ = "Janosh Riebesell"
Expand All @@ -25,8 +22,8 @@
e_form_preds = "e_form_per_atom_pred"
each_pred_col = "e_above_hull_pred"
facet_col = "Model"
models = list(df_metrics)
# models = df_metrics.T.MAE.nsmallest(6).index # top 6 models by MAE
# sort facet plots by model's F1 scores (optionally only show top n=6)
models = list(df_metrics.T.F1.sort_values().index)[::-1]

df_melt = df_preds.melt(
id_vars=hover_cols,
Expand All @@ -45,7 +42,7 @@
rows, cols = len(models) // 2, 2
which_energy: Final = "true"
kwds = (
dict(facet_col=facet_col, facet_col_wrap=cols)
dict(facet_col=facet_col, facet_col_wrap=cols, category_orders={facet_col: models})
if backend == "plotly"
else dict(by=facet_col, figsize=(20, 20), layout=(rows, cols), bins=500)
)
Expand Down
3 changes: 1 addition & 2 deletions scripts/prc_roc_curves_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
from tqdm import tqdm

from matbench_discovery import FIGS
from matbench_discovery.plots import pio
from matbench_discovery import plots as plots
from matbench_discovery.preds import df_each_pred, df_preds, each_true_col

__author__ = "Janosh Riebesell"
__date__ = "2023-01-30"


pio.templates.default
line = dict(dash="dash", width=0.5)

facet_col = "Model"
Expand Down
8 changes: 4 additions & 4 deletions scripts/rolling_mae_vs_hull_dist_wbm_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
__author__ = "Rhys Goodall, Janosh Riebesell"
__date__ = "2022-06-18"

df_each_pred[(batch_col := "batch_idx")] = (
"Batch " + df_each_pred.index.str.split("-").str[1]
)
batch_col = "batch_idx"
df_each_pred[batch_col] = "Batch " + df_each_pred.index.str.split("-").str[1]
df_err, df_std = None, None # variables to cache rolling MAE and std


Expand Down Expand Up @@ -43,6 +42,7 @@
backend="matplotlib",
ax=ax,
just_plot_lines=idx > 1,
pbar=False,
)


Expand All @@ -54,7 +54,7 @@


# %% plotly
model = "Wrenformer" # ["M3GNet", "Wrenformer", "MEGNet", "Voronoi RF"]
model = "CHGNet"
df_pivot = df_each_pred.pivot(columns=batch_col, values=model)

# unstack two-level column index into new model column
Expand Down

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion site/src/figs/hist-clf-true-hull-dist-models.svelte

Large diffs are not rendered by default.

Loading

0 comments on commit 8a0c3bb

Please sign in to comment.