Skip to content

Commit

Permalink
eda_mp_trj.py add code for mp-trj-forces-ptable-hists.pdf showing the…
Browse files Browse the repository at this point in the history
… distribution of forces for each element in the periodic table

per_element_errors.py add ptable-element-wise-each-error-hists-{model}.pdf
  • Loading branch information
janosh committed Dec 3, 2023
1 parent 46366d1 commit 62a6458
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 21 deletions.
100 changes: 81 additions & 19 deletions data/mp/eda_mp_trj.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
# %%
import io
import os
from typing import Any
from zipfile import ZipFile

import ase
import ase.io.extxyz
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
Expand Down Expand Up @@ -37,6 +40,7 @@
e_form_per_atom_col = "ef_per_atom"
magmoms_col = "magmoms"
forces_col = "forces"
elems_col = "symbols"


# %% load MP element counts by occurrence to compute ratio with MPtrj
Expand All @@ -46,7 +50,7 @@
df_mp = pd.read_csv(DATA_FILES.mp_energies, na_filter=False).set_index(id_col)


# %% --- load preprocessed MPtrj summary data ---
# %% --- load preprocessed MPtrj summary data if available ---
mp_trj_summary_path = f"{DATA_DIR}/mp/mp-trj-2022-09-summary.json.bz2"
if os.path.isfile(mp_trj_summary_path):
df_mp_trj = pd.read_json(mp_trj_summary_path)
Expand Down Expand Up @@ -84,9 +88,9 @@

df_mp_trj = pd.DataFrame(
{
info_to_id(atoms.info): {"formula": str(atoms.symbols)}
info_to_id(atoms.info): atoms.info
| {key: atoms.arrays.get(key) for key in ("forces", "magmoms")}
| atoms.info
| {"formula": str(atoms.symbols), elems_col: atoms.symbols}
for atoms_list in tqdm(mp_trj_atoms.values(), total=len(mp_trj_atoms))
for atoms in atoms_list
}
Expand All @@ -106,41 +110,97 @@
df_mp_trj.to_json(mp_trj_summary_path)


# %%
def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
    """Annotate each periodic table tile with the number of values in its histogram.

    Args:
        hist_vals: Values that went into the histogram of one element tile.

    Returns:
        dict with keys "text" (the value count formatted by si_fmt) and "bbox"
        (box style for the annotation). The box face color is looked up from the
        module-level cmap/norm based on the count, or "none" for empty tiles.
    """
    n_vals = len(hist_vals)
    # Compare the count explicitly rather than relying on truthiness of hist_vals:
    # array-likes (numpy arrays, pandas Series) raise on bool(). This also avoids
    # the redundant np.sum() around an already-scalar length.
    facecolor = cmap(norm(n_vals)) if n_vals > 0 else "none"
    bbox = dict(facecolor=facecolor, alpha=0.4, pad=2, edgecolor="none")
    return dict(text=si_fmt(n_vals, ".0f"), bbox=bbox)


# %% plot per-element magmom histograms
magmom_hist_path = f"{DATA_DIR}/mp/mp-trj-2022-09-elem-magmoms.json.bz2"

if os.path.isfile(magmom_hist_path):
mp_trj_elem_magmoms = pd.read_json(magmom_hist_path, typ="series")
elif "mp_trj_elem_magmoms" not in locals():
df_mp_trj_magmom = pd.DataFrame(
{
info_to_id(atoms.info): (
dict(zip(atoms.symbols, atoms.arrays["magmoms"], strict=True))
if magmoms_col in atoms.arrays
else None
)
for frame_id in tqdm(mp_trj_atoms)
for atoms in mp_trj_atoms[frame_id]
}
).T.dropna(axis=0, how="all")
# project magmoms onto symbols in dict
df_mp_trj_elem_magmom = pd.DataFrame(
[
dict(zip(elems, magmoms))
for elems, magmoms in df_mp_trj.set_index(elems_col)[magmoms_col]
.dropna()
.items()
]
)

mp_trj_elem_magmoms = {
col: list(df_mp_trj_magmom[col].dropna()) for col in df_mp_trj_magmom
col: list(df_mp_trj_elem_magmom[col].dropna()) for col in df_mp_trj_elem_magmom
}
pd.Series(mp_trj_elem_magmoms).to_json(magmom_hist_path)

cmap = plt.cm.get_cmap("viridis")
norm = matplotlib.colors.LogNorm(vmin=1, vmax=150_000)

ax = ptable_hists(
mp_trj_elem_magmoms,
symbol_pos=(0.2, 0.8),
log=True,
cbar_title="Magmoms ($μ_B$)",
cbar_title_kwds=dict(fontsize=16),
cbar_coords=(0.18, 0.85, 0.42, 0.02),
# annotate each element with its number of magmoms in MPtrj
anno_kwds=dict(text=lambda hist_vals: si_fmt(len(hist_vals), ".0f")),
anno_kwds=tile_count_anno,
)

cbar_ax = ax.figure.add_axes([0.26, 0.78, 0.25, 0.015])
cbar = matplotlib.colorbar.ColorbarBase(
cbar_ax, cmap=cmap, norm=norm, orientation="horizontal"
)
save_fig(ax, f"{PDF_FIGS}/mp-trj-magmoms-ptable-hists.pdf")


# %% plot per-element force histograms
# Path of the cached per-element force lists; reused on later runs when present.
force_hist_path = f"{DATA_DIR}/mp/mp-trj-2022-09-elem-forces.json.bz2"

if os.path.isfile(force_hist_path):
    mp_trj_elem_forces = pd.read_json(force_hist_path, typ="series")
elif "mp_trj_elem_forces" not in locals():
    # One row per frame, one column per element symbol. Each cell is the mean of
    # the absolute force components (axis=1) of an atom of that element.
    # NOTE(review): dict(zip(...)) keeps only the last atom for every repeated
    # symbol in a frame -- confirm that dropping the other atoms is intended.
    df_mp_trj_elem_forces = pd.DataFrame(
        [
            dict(zip(elems, np.abs(forces).mean(axis=1)))
            for elems, forces in df_mp_trj.set_index(elems_col)[forces_col].items()
        ]
    )
    # Collapse each element column into a plain list of its non-NaN values.
    mp_trj_elem_forces = pd.Series(
        {
            col: list(df_mp_trj_elem_forces[col].dropna())
            for col in df_mp_trj_elem_forces
        }
    )
    mp_trj_elem_forces.to_json(force_hist_path)

# plt.cm.get_cmap is matplotlib.cm.get_cmap, which was deprecated in matplotlib 3.7
# and removed in 3.9. plt.get_cmap is the supported equivalent.
cmap = plt.get_cmap("viridis")
norm = matplotlib.colors.LogNorm(vmin=1, vmax=1_000_000)

max_force = 10  # eV/Å
ax = ptable_hists(
    # Series.map already returns a new Series, so no defensive copy is needed
    mp_trj_elem_forces.map(lambda x: [val for val in x if val < max_force]),
    symbol_pos=(0.3, 0.8),
    log=True,
    cbar_title="1/3 Σ|Forces| (eV/Å)",
    cbar_title_kwds=dict(fontsize=16),
    cbar_coords=(0.18, 0.85, 0.42, 0.02),
    x_range=(0, max_force),
    anno_kwds=tile_count_anno,
)

# Second, smaller colorbar built from the same cmap/norm that are assigned above.
cbar_ax = ax.figure.add_axes([0.26, 0.78, 0.25, 0.015])
cbar = matplotlib.colorbar.ColorbarBase(
    cbar_ax, cmap=cmap, norm=norm, orientation="horizontal"
)

save_fig(ax, f"{PDF_FIGS}/mp-trj-forces-ptable-hists.pdf")


# %%
elem_counts: dict[str, dict[str, int]] = {}
for count_mode in ("composition", "occurrence"):
Expand All @@ -153,9 +213,11 @@


# %%
count_mode = "composition"
if "trj_elem_counts" not in locals():
trj_elem_counts = pd.read_json(
f"{data_page}/mp-trj-element-counts-by-occurrence.json", typ="series"
f"{data_page}/mp-trj-element-counts-by-{count_mode}.json",
typ="series",
)

excl_elems = "He Ne Ar Kr Xe".split() if (excl_noble := False) else ()
Expand All @@ -167,12 +229,12 @@
zero_color="#efefef",
log=(log := True),
exclude_elements=excl_elems, # drop noble gases
cbar_range=None if excl_noble else (2000, None),
cbar_range=None if excl_noble else (10_000, None),
label_font_size=17,
value_font_size=14,
)

img_name = f"mp-trj-element-counts-by-occurrence{'-log' if log else ''}"
img_name = f"mp-trj-element-counts-by-{count_mode}{'-log' if log else ''}"
if excl_noble:
img_name += "-excl-noble"
save_fig(ax_ptable, f"{PDF_FIGS}/{img_name}.pdf")
Expand Down
2 changes: 1 addition & 1 deletion matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
global_layout = dict(
# colorway=px.colors.qualitative.Pastel,
# colorway=colorway,
margin=dict(l=30, r=20, t=60, b=20),
# margin=dict(l=30, r=20, t=60, b=20),
paper_bgcolor="rgba(0,0,0,0)",
# plot_bgcolor="rgba(0,0,0,0)",
font_size=13,
Expand Down
16 changes: 15 additions & 1 deletion scripts/model_figs/per_element_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas as pd
import plotly.express as px
from pymatgen.core import Composition, Element
from pymatviz import ptable_heatmap_plotly
from pymatviz import ptable_heatmap_plotly, ptable_hists
from pymatviz.io import save_fig
from pymatviz.utils import bin_df_cols, df_ptable
from tqdm import tqdm
Expand Down Expand Up @@ -256,3 +256,17 @@

fig.show()
save_fig(fig, f"{SITE_FIGS}/each-error-vs-least-prevalent-element-in-struct.svelte")


# %% plot histogram of model errors for each element
model = "MACE"
# Turn the model's error column into an (n, 1) column vector so it broadcasts
# across every column of df_frac_comp (presumably one column per element --
# verify against where df_frac_comp is built).
each_err_col_vec = df_each_err[model].to_numpy()[:, None]
ax = ptable_hists(
    df_frac_comp * each_err_col_vec,
    symbol_pos=(0.1, 0.8),
    x_range=(-0.5, 0.5),
    cbar_title=f"{model} convex hull distance errors (eV/atom)",
    log=True,
)

# File name embeds the lower-cased model name with spaces replaced by dashes.
img_name = f"ptable-each-error-hists-{model.lower().replace(' ', '-')}"
save_fig(ax, f"{PDF_FIGS}/{img_name}.pdf")

0 comments on commit 62a6458

Please sign in to comment.