Skip to content

Commit

Permalink
add missing doc strings
Browse files Browse the repository at this point in the history
reduce max line length 95->88
fix new overlong lines
  • Loading branch information
janosh committed Jun 20, 2023
1 parent 81dc5f0 commit 1ae72c0
Show file tree
Hide file tree
Showing 30 changed files with 162 additions and 93 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.249
rev: v0.0.252
hooks:
- id: ruff
args: [--fix]
Expand Down
5 changes: 5 additions & 0 deletions data/mp/build_phase_diagram.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""Build a PatchedPhaseDiagram from all MP ComputedStructureEntries for calculating
DFT-ground truth convex hull energies.
"""


# %%
import gzip
import json
Expand Down
8 changes: 4 additions & 4 deletions matbench_discovery/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
default_cache_dir = os.path.expanduser("~/.cache/matbench-discovery")

DATA_FILENAMES = {
"mp-computed-structure-entries": "mp/2022-09-16-mp-computed-structure-entries.json.gz",
"mp-computed-structure-entries": "mp/2022-09-16-mp-computed-structure-entries.json.gz", # noqa: E501
"mp-elemental-ref-energies": "mp/2022-09-19-mp-elemental-ref-energies.json",
"mp-energies": "mp/2022-08-13-mp-energies.json.gz",
"mp-patched-phase-diagram": "mp/2023-02-07-ppd-mp.pkl.gz",
"wbm-computed-structure-entries": "wbm/2022-10-19-wbm-computed-structure-entries.json.bz2",
"wbm-computed-structure-entries": "wbm/2022-10-19-wbm-computed-structure-entries.json.bz2", # noqa: E501
"wbm-initial-structures": "wbm/2022-10-19-wbm-init-structs.json.bz2",
"wbm-summary": "wbm/2022-10-19-wbm-summary.csv",
}
Expand Down Expand Up @@ -60,10 +60,10 @@ def load_train_test(
Recognized data keys are mp-computed-structure-entries, mp-elemental-ref-energies,
mp-energies, mp-patched-phase-diagram, wbm-computed-structure-entries,
wbm-initial-structures, wbm-summary. See
https://janosh.github.io/matbench-discovery/how-to-contribute for brief data descriptions.
https://janosh.github.io/matbench-discovery/how-to-contribute for data descriptions.
Args:
data_names (str | list[str], optional): Which parts of the MP/WBM dataset to load.
data_names (str | list[str], optional): Which parts of the MP/WBM data to load.
Can be any subset of the above data names or 'all'. Defaults to ["summary"].
version (str, optional): Which version of the dataset to load. Defaults to
'1.0.0'. Can be any git tag, branch or commit hash.
Expand Down
2 changes: 1 addition & 1 deletion matbench_discovery/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
def get_elemental_ref_entries(
entries: Sequence[EntryLike], verbose: bool = True
) -> dict[str, Entry]:
"""Get the lowest energy pymatgen Entry object for each element in a list of entries.
"""Get the lowest energy pymatgen Entry for each element in a list of entries.
Args:
entries (Sequence[Entry]): pymatgen Entries (PDEntry, ComputedEntry or
Expand Down
12 changes: 6 additions & 6 deletions matbench_discovery/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ def classify_stable(
Args:
e_above_hull_true (pd.Series): Ground truth energy above convex hull values.
e_above_hull_pred (pd.Series): Model predicted energy above convex hull values.
stability_threshold (float | None, optional): Maximum energy above convex hull for a
material to still be considered stable. Usually 0, 0.05 or 0.1. Defaults to
0, meaning a material has to be directly on the hull to be called stable.
Negative values mean a material has to pull the known hull down by that
amount to count as stable. Few materials lie below the known hull, so only
negative values very close to 0 make sense.
stability_threshold (float | None, optional): Maximum energy above convex hull
for a material to still be considered stable. Usually 0, 0.05 or 0.1.
Defaults to 0, meaning a material has to be directly on the hull to be
called stable. Negative values mean a material has to pull the known hull
down by that amount to count as stable. Few materials lie below the known
hull, so only negative values very close to 0 make sense.
Returns:
tuple[TP, FN, FP, TN]: Indices as pd.Series for true positives,
Expand Down
46 changes: 24 additions & 22 deletions matbench_discovery/plots.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Plotting functions for analyzing model performance on materials discovery."""

from __future__ import annotations

import math
Expand Down Expand Up @@ -115,23 +117,23 @@ def hist_classified_stable_vs_hull_dist(
),
**kwargs: Any,
) -> plt.Axes | go.Figure:
"""Histogram of the energy difference (either according to DFT ground truth [default]
or model predicted energy) to the convex hull for materials in the WBM data set. The
histogram is broken down into true positives, false negatives, false positives, and
true negatives based on whether the model predicts candidates to be below the known
convex hull. Ideally, in discovery setting a model should exhibit high recall, i.e.
the majority of materials below the convex hull being correctly identified by the
model.
"""Histogram of the energy difference (either according to DFT ground truth - the
default - or the model predicted energy) to the convex hull for materials in the
WBM data set. The histogram is broken down into true positives, false negatives,
false positives, and true negatives based on whether the model predicts candidates
to be below the known convex hull. Ideally, in a discovery setting a model should
exhibit high recall, i.e. the majority of materials below the convex hull being
correctly identified by the model.
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
Args:
df (pd.DataFrame): Data frame containing true and predicted hull distances.
each_true_col (str): Name of column with energy above convex hull according to DFT
ground truth (in eV / atom).
each_pred_col (str): Name of column with energy above convex hull predicted by model
(in eV / atom). Same as true energy to convex hull plus predicted minus true
formation energy.
each_true_col (str): Name of column with energy above convex hull according to
DFT ground truth (in eV / atom).
each_pred_col (str): Name of column with energy above convex hull predicted by
model (in eV / atom). Same as true energy to convex hull plus predicted
minus true formation energy.
ax (plt.Axes, optional): matplotlib axes to plot on.
which_energy ('true' | 'pred', optional): Whether to use the true (DFT) hull
distance or the model's predicted hull distance for the histogram.
Expand Down Expand Up @@ -162,7 +164,7 @@ def hist_classified_stable_vs_hull_dist(
each_true, each_pred, stability_threshold
)

# toggle between histogram of DFT-computed or model-predicted distance to convex hull
# switch between histogram of DFT-computed or model-predicted convex hull distance
e_above_hull = df[x_col]
eah_true_pos = e_above_hull[true_pos]
eah_true_neg = e_above_hull[true_neg]
Expand Down Expand Up @@ -335,9 +337,9 @@ def rolling_mae_vs_hull_dist(
to False.
with_sem (bool, optional): If True, plot the standard error of the mean as
shaded area around the rolling MAE. Defaults to True.
show_dft_acc (bool, optional): If True, change color of the triangle of peril's tip
and annotate it with 'Corrected GGA Accuracy' at rolling MAE of 25 meV/atom.
Defaults to False.
show_dft_acc (bool, optional): If True, change color of the triangle of peril's
tip and annotate it with 'Corrected GGA Accuracy' at rolling MAE of 25
meV/atom. Defaults to False.
show_dummy_mae (bool, optional): If True, plot a line at the dummy MAE of always
predicting the target mean.
**kwargs: Additional keyword arguments to pass to df.plot().
Expand Down Expand Up @@ -590,10 +592,10 @@ def cumulative_precision_recall(
project_end_point ('x' | 'y' | 'xy' | '', optional): Whether to project end
points of precision and recall curves to the x/y axis. Defaults to '', i.e. no
axis projection lines.
optimal_recall (str | None, optional): Label for the optimal recall line. Defaults
to 'Optimal Recall'. Set to None to not plot the line.
show_n_stable (bool, optional): Whether to show a horizontal line at the true number
of stable materials. Defaults to True.
optimal_recall (str | None, optional): Label for the optimal recall line.
Defaults to 'Optimal Recall'. Set to None to not plot the line.
show_n_stable (bool, optional): Whether to show a horizontal line at the true
number of stable materials. Defaults to True.
backend ('matplotlib' | 'plotly', optional): Which plotting engine to use.
Changes the return type. Defaults to 'plotly'.
**kwargs: Keyword arguments passed to df.plot().
Expand Down Expand Up @@ -687,8 +689,8 @@ def cumulative_precision_recall(
ax.plot((0, x_end), (y_end, y_end), **intersect_kwargs)

# optimal recall line finds all stable materials without any false positives
# can be included to confirm all models achieve near optimal recall initially
# and to see how much they overshoot n_stable
# can be included to confirm all models achieve near optimal recall
# initially and to see how much they overshoot n_stable
if optimal_recall and "Recall" in metric:
ax.plot([0, n_stable], [0, 1], color="green", linestyle="--")
ax.text(
Expand Down
2 changes: 1 addition & 1 deletion models/bowsr/test_bowsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
# --time 2h is probably enough but best be safe.
array=f"1-{slurm_array_task_count}%{slurm_max_parallel}",
# --mem 12000 avoids slurmstepd: error: Detected 1 oom-kill event(s)
# Some of your processes may have been killed by the cgroup out-of-memory handler.
# Some of your processes may have been killed by the cgroup out-of-memory handler.
slurm_flags=("--mem", str(12_000)),
# TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
# https://stackoverflow.com/a/40982782
Expand Down
5 changes: 5 additions & 0 deletions models/m3gnet/join_m3gnet_results.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""Concatenate M3GNet results from multiple data files generated by slurm job array
into a single file.
"""


# %%
from __future__ import annotations

Expand Down
13 changes: 7 additions & 6 deletions models/m3gnet/test_m3gnet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""Get M3GNet formation energy predictions on WBM test set.
To slurm submit this file: python path/to/file.py slurm-submit
Requires M3GNet installation: pip install m3gnet
https://github.com/materialsvirtuallab/m3gnet.
"""


# %%
from __future__ import annotations

Expand All @@ -16,12 +23,6 @@
from matbench_discovery.data import as_dict_handler
from matbench_discovery.slurm import slurm_submit

"""
To slurm submit this file: python path/to/file.py slurm-submit
Requires M3GNet installation: pip install m3gnet
https://github.com/materialsvirtuallab/m3gnet
"""

__author__ = "Janosh Riebesell"
__date__ = "2022-08-15"

Expand Down
3 changes: 3 additions & 0 deletions models/m3gnet/wbm_pre_vs_post_m3gnet_relaxation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""Compare M3GNet-relaxed vs DFT-relaxed WBM lattice volumes and angles."""


# %%
import os

Expand Down
13 changes: 7 additions & 6 deletions models/megnet/test_megnet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""Get MEGNet formation energy predictions on WBM test set.
To slurm submit this file: python path/to/file.py slurm-submit
Requires MEGNet installation: pip install megnet
See https://github.com/materialsvirtuallab/megnet.
"""


# %%
from __future__ import annotations

Expand All @@ -15,12 +22,6 @@
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit

"""
To slurm submit this file: python path/to/file.py slurm-submit
Requires MEGNet installation: pip install megnet
https://github.com/materialsvirtuallab/megnet
"""

__author__ = "Janosh Riebesell"
__date__ = "2022-11-14"

Expand Down
5 changes: 5 additions & 0 deletions models/voronoi/join_voronoi_features.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""Concatenate Voronoi features from multiple data files generated by slurm job array
into a single file.
"""


# %%
from __future__ import annotations

Expand Down
3 changes: 3 additions & 0 deletions models/voronoi/train_test_voronoi_rf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""Train and test a Voronoi RandomForestRegressor model."""


# %%
import os
from importlib.metadata import version
Expand Down
5 changes: 5 additions & 0 deletions models/voronoi/voronoi_featurize_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""Featurize MP training and WBM test structures with Magpie composition-based and
Voronoi tessellation structure-based features.
"""


# %%
import os
import sys
Expand Down
11 changes: 6 additions & 5 deletions models/wrenformer/test_wrenformer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""Download WandB checkpoints for an ensemble of Wrenformer models trained on all MP
formation energies, then make predictions on some dataset, print ensemble metrics and
save predictions to CSV.
"""


# %%
from __future__ import annotations

Expand All @@ -18,11 +24,6 @@
__author__ = "Janosh Riebesell"
__date__ = "2022-08-15"

"""
Download WandB checkpoints for an ensemble of Wrenformer models trained on all MP
formation energies, then makes predictions on some dataset, prints ensemble metrics and
saves predictions to CSV.
"""

task_type = "IS2RE"
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-summary.csv"
Expand Down
7 changes: 3 additions & 4 deletions models/wrenformer/train_wrenformer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""Train a Wrenformer ensemble on target_col of data_path."""


# %%
import os
from importlib.metadata import version
Expand All @@ -8,10 +11,6 @@
from matbench_discovery import DEBUG, ROOT, WANDB_PATH, timestamp, today
from matbench_discovery.slurm import slurm_submit

"""
Train a Wrenformer ensemble on target_col of data_path.
"""

__author__ = "Janosh Riebesell"
__date__ = "2022-08-13"

Expand Down
5 changes: 5 additions & 0 deletions scripts/compile_metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""Compile metrics and total run times for all models and export them to JSON, a
pandas-styled HTML table and a plotly figure.
"""


# %%
from __future__ import annotations

Expand Down
10 changes: 10 additions & 0 deletions scripts/cumulative_clf_metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
"""Plot cumulative precision and/or recall and/or F1 curves for all models into facet
plot with one subplot per metric. Cumulative here means going through the list of WBM
materials ranked by the model's stability prediction starting from the most stable
and updating the precision, recall and F1 score after each new material. This plot
simulates an actual materials screening process and allows practitioners to choose
a cutoff point for the number of DFT calculations they have budget for and see which
model will provide the best hit rate for the given budget.
"""


# %%
import pandas as pd
from pymatviz.utils import save_fig
Expand Down
6 changes: 6 additions & 0 deletions scripts/difficult_structures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""Analyze structures and composition with largest mean error across all models.
Maybe there's some chemistry/region of materials space that all models struggle with?
Might point to deficiencies in the data or model architecture.
"""


# %%
import matplotlib.pyplot as plt
import pandas as pd
Expand Down
16 changes: 8 additions & 8 deletions scripts/hist_classified_stable_vs_hull_dist.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
"""Histogram of the energy difference (either according to DFT ground truth [default] or
model predicted energy) to the convex hull for materials in the WBM data set. The
histogram stacks true/false positives/negatives with different colors.
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
"""


# %%
from typing import Final

Expand All @@ -11,14 +19,6 @@
__author__ = "Rhys Goodall, Janosh Riebesell"
__date__ = "2022-06-18"

"""
Histogram of the energy difference (either according to DFT ground truth [default] or
model predicted energy) to the convex hull for materials in the WBM data set. The
histogram stacks true/false positives/negatives with different colors.
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
"""


# %%
model_name = "Wrenformer"
Expand Down
16 changes: 8 additions & 8 deletions scripts/hist_classified_stable_vs_hull_dist_batches.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
"""Histogram of the energy difference (either according to DFT ground truth [default] or
model predicted energy) to the convex hull for materials in the WBM data set. The
histogram stacks true/false positives/negatives with different colors.
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
"""


# %%
from typing import Final

Expand All @@ -8,14 +16,6 @@
from matbench_discovery.plots import hist_classified_stable_vs_hull_dist
from matbench_discovery.preds import df_wbm, e_form_col, each_pred_col, each_true_col

"""
Histogram of the energy difference (either according to DFT ground truth [default] or
model predicted energy) to the convex hull for materials in the WBM data set. The
histogram stacks true/false positives/negatives with different colors.
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
"""

__author__ = "Rhys Goodall, Janosh Riebesell"
__date__ = "2022-08-25"

Expand Down
Loading

0 comments on commit 1ae72c0

Please sign in to comment.