Skip to content

Commit

Permalink
add wandb_log_scatter() to plots.py, use in test_{wrenformer,cgcnn,me…
Browse files Browse the repository at this point in the history
…gnet}

fix unexpected kwarg default_handler=as_dict_handler in voronoi_featurize_dataset.py
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 255429a commit 65172ff
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 40 deletions.
29 changes: 29 additions & 0 deletions matbench_discovery/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import plotly.io as pio
import scipy.interpolate
import scipy.stats
import wandb
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

__author__ = "Janosh Riebesell"
Expand Down Expand Up @@ -379,6 +380,7 @@ def cumulative_clf_metric(
axis projection lines.
show_optimal (bool, optional): Whether to plot the optimal precision/recall
line. Defaults to False.
**kwargs: Keyword arguments passed to ax.plot().
Returns:
plt.Axes: The matplotlib axes object.
Expand Down Expand Up @@ -474,3 +476,30 @@ def cumulative_clf_metric(
)

return ax


def wandb_log_scatter(
table: wandb.Table, fields: dict[str, str], **kwargs: Any
) -> None:
"""Log a parity scatter plot using custom vega spec to WandB.
Args:
table (wandb.Table): WandB data table.
fields (dict[str, str]): Map from table columns to fields defined in the custom
vega spec. Currently the only Vega fields are 'x' and 'y'.
**kwargs: Keyword arguments passed to wandb.plot_table(string_fields=kwargs).
"""
assert set(fields) >= {"x", "y"}, f"{fields=} must specify x and y column names"

if all("form" in field for field in fields.values()):
kwargs.setdefault("x", "DFT formation energy (eV/atom)")
kwargs.setdefault("y", "Predicted formation energy (eV/atom)")

scatter_plot = wandb.plot_table(
vega_spec_name="janosh/scatter-parity",
data_table=table,
fields=fields,
string_fields=kwargs,
)

wandb.log({"true_pred_scatter": scatter_plot})
19 changes: 7 additions & 12 deletions models/cgcnn/test_cgcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from matbench_discovery import ROOT
from matbench_discovery.plot_scripts import df_wbm
from matbench_discovery.plots import wandb_log_scatter
from matbench_discovery.slurm import slurm_submit

__author__ = "Janosh Riebesell"
Expand Down Expand Up @@ -67,7 +68,7 @@
assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {filters=}"
for idx, run in enumerate(runs):
for key, val in run.config.items():
if val == runs[0][key] or key.startswith(("slurm_", "timestamp")):
if val == runs[0].config[key] or key.startswith(("slurm_", "timestamp")):
continue
raise ValueError(
f"Configs not identical: runs[{idx}][{key}]={val}, {runs[0][key]=}"
Expand Down Expand Up @@ -106,21 +107,15 @@
)

df.to_csv(f"{log_dir}/{today}-{job_name}-preds.csv", index=False)
table = wandb.Table(dataframe=df)
pred_col = f"{target_col}_pred_ens"
table = wandb.Table(dataframe=df[[target_col, pred_col]].reset_index())


# %%
pred_col = f"{target_col}_pred_ens"
MAE = ensemble_metrics["MAE"]
R2 = ensemble_metrics["R2"]
MAE = ensemble_metrics.MAE.mean()
R2 = ensemble_metrics.R2.mean()

title = rf"CGCNN {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"
print(title)

scatter_plot = wandb.plot_table(
vega_spec_name="janosh/scatter-parity",
data_table=table,
fields=dict(x=target_col, y=pred_col, title=title),
)

wandb.log({"true_pred_scatter": scatter_plot})
wandb_log_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)
9 changes: 2 additions & 7 deletions models/megnet/test_megnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from matbench_discovery import ROOT
from matbench_discovery.plot_scripts import df_wbm
from matbench_discovery.plots import wandb_log_scatter
from matbench_discovery.slurm import slurm_submit

"""
Expand Down Expand Up @@ -122,10 +123,4 @@
title = f"{model_name} {task_type} {MAE=:.4} {R2=:.4}"
print(title)

scatter_plot = wandb.plot_table(
vega_spec_name="janosh/scatter-parity",
data_table=table,
fields=dict(x=target_col, y=pred_col, title=title),
)

wandb.log({"true_pred_scatter": scatter_plot})
wandb_log_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)
8 changes: 3 additions & 5 deletions models/voronoi/voronoi_featurize_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pymatgen.core import Structure
from tqdm import tqdm

from matbench_discovery import ROOT, as_dict_handler
from matbench_discovery import ROOT
from matbench_discovery.slurm import slurm_submit
from models.voronoi import featurizer

Expand All @@ -34,7 +34,7 @@
account="LEE-SL3-CPU",
time=(slurm_max_job_time := "12:0:0"),
array=f"1-{slurm_array_task_count}",
slurm_flags=("--mem", "20G") if data_name == "mp" else (),
slurm_flags=("--mem", "15G") if data_name == "mp" else (),
log_dir=log_dir,
)

Expand Down Expand Up @@ -90,8 +90,6 @@


# %%
df_features[featurizer.feature_labels()].to_csv(
out_path, default_handler=as_dict_handler
)
df_features[featurizer.feature_labels()].to_csv(out_path)

wandb.log({"voronoi_features": wandb.Table(dataframe=df_features)})
26 changes: 11 additions & 15 deletions models/wrenformer/test_wrenformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,30 @@
from aviary.wrenformer.model import Wrenformer

from matbench_discovery import ROOT
from matbench_discovery.plots import wandb_log_scatter
from matbench_discovery.slurm import slurm_submit

__author__ = "Janosh Riebesell"
__date__ = "2022-08-15"

"""
Download wandb checkpoints for an ensemble of Wrenformer models trained on MP
Download WandB checkpoints for an ensemble of Wrenformer models trained on MP
formation energies, then makes predictions on some dataset, prints ensemble metrics and
stores predictions to CSV.
"""

module_dir = os.path.dirname(__file__)
today = f"{datetime.now():%Y-%m-%d}"
task_type = "IS2RE"
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-summary.csv"
job_name = "wrenformer-wbm-IS2RE"
job_name = "test-wrenformer-wbm-IS2RE"
log_dir = f"{os.path.dirname(__file__)}/{today}-{job_name}"

slurm_vars = slurm_submit(
job_name=job_name,
partition="ampere",
account="LEE-SL3-GPU",
time="2:0:0",
log_dir=module_dir,
log_dir=log_dir,
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
)

Expand All @@ -60,7 +61,7 @@
assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {filters=}"
for idx, run in enumerate(runs):
for key, val in run.config.items():
if val == runs[0][key] or key.startswith(("slurm_", "timestamp")):
if val == runs[0].config[key] or key.startswith(("slurm_", "timestamp")):
continue
raise ValueError(
f"Configs not identical: runs[{idx}][{key}]={val}, {runs[0][key]=}"
Expand Down Expand Up @@ -98,25 +99,20 @@
runs, data_loader=data_loader, df=df, model_cls=Wrenformer, target_col=target_col
)

df.to_csv(f"{module_dir}/{today}-{job_name}-preds.csv")
df.to_csv(f"{log_dir}/{job_name}-preds.csv")


# %%
pred_col = f"{target_col}_pred_ens"
assert pred_col in df, f"{pred_col=} not in {list(df)}"
table = wandb.Table(dataframe=df[[target_col, pred_col]])
table = wandb.Table(dataframe=df[[target_col, pred_col]].reset_index())


# %%
MAE = ensemble_metrics["MAE"]
R2 = ensemble_metrics["R2"]
MAE = ensemble_metrics.MAE.mean()
R2 = ensemble_metrics.R2.mean()

title = rf"Wrenformer {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"
fields = dict(x=target_col, y=pred_col, title=title)
print(title)

scatter_plot = wandb.plot_table(
vega_spec_name="janosh/scatter-parity", data_table=table, fields=fields
)

wandb.log({"true_pred_scatter": scatter_plot})
wandb_log_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)
2 changes: 1 addition & 1 deletion models/wrenformer/train_wrenformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# data_path = f"{ROOT}/data/2022-08-25-m3gnet-trainset-mp-2021-struct-energy.json.gz"
# target_col = "mp_energy_per_atom"
data_name = "m3gnet-trainset" if "m3gnet" in data_path else "mp"
run_name = f"train-wrenformer-robust-{data_name}-{target_col}"
run_name = f"train-wrenformer-robust-{data_name}"
n_ens = 10
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
today = timestamp.split("@")[0]
Expand Down

0 comments on commit 65172ff

Please sign in to comment.