Skip to content

Commit

Permalink
merge load_model_preds() into load_df_wbm_with_preds() and pass through **kwargs to pandas reader
Browse files Browse the repository at this point in the history

fix test_load_preds.py and test_plots.py
  • Loading branch information
janosh committed Dec 9, 2022
1 parent b0d7cca commit 97f2cc6
Show file tree
Hide file tree
Showing 19 changed files with 1,538,790 additions and 56 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ job-logs/

# slurm logs
*slurm-*.log
models/**/*.csv

# temporary ignore rules
paper
Expand Down
34 changes: 22 additions & 12 deletions matbench_discovery/load_preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@


def glob_to_df(
pattern: str, reader: Callable[[Any], pd.DataFrame] = None, pbar: bool = True
pattern: str,
reader: Callable[[Any], pd.DataFrame] = None,
pbar: bool = True,
**kwargs: Any,
) -> pd.DataFrame:
"""Combine data files matching a glob pattern into a single dataframe.
Expand All @@ -31,6 +34,7 @@ def glob_to_df(
reader (Callable[[Any], pd.DataFrame], optional): Function that loads data from
disk. Defaults to pd.read_csv if ".csv" in pattern else pd.read_json.
pbar (bool, optional): Whether to show progress bar. Defaults to True.
**kwargs: Keyword arguments passed to reader (e.g. pd.read_csv or pd.read_json).
Returns:
pd.DataFrame: Combined dataframe.
Expand All @@ -44,27 +48,34 @@ def glob_to_df(

sub_dfs = {} # used to join slurm job array results into single df
for file in tqdm(files, disable=not pbar):
df = reader(file)
df = reader(file, **kwargs)
sub_dfs[file] = df

return pd.concat(sub_dfs.values())


def load_model_preds(
models: Sequence[str], pbar: bool = True, id_col: str = "material_id"
) -> dict[str, pd.DataFrame]:
"""Load model predictions from disk into dictionary of dataframes.
def load_df_wbm_with_preds(
models: Sequence[str],
pbar: bool = True,
id_col: str = "material_id",
return_model_dfs: bool = False,
**kwargs: Any,
) -> pd.DataFrame | dict[str, pd.DataFrame]:
"""Load WBM summary dataframe with model predictions from disk.
Args:
models (Sequence[str]): Model names must be keys of data_paths dict.
pbar (bool, optional): Whether to show progress bar. Defaults to True.
id_col (str, optional): Column to set as df.index. Defaults to "material_id".
return_model_dfs (bool, optional): Whether to return a dict of per-model
dataframes instead of the WBM summary dataframe. Defaults to False.
**kwargs: Keyword arguments passed to glob_to_df().
Raises:
ValueError: On unknown model names.
Returns:
dict[str, pd.DataFrame]: Dictionary of dataframes, one for each model.
pd.DataFrame: WBM summary dataframe with model predictions.
"""
if mismatch := ", ".join(set(models) - set(DATA_PATHS)):
raise ValueError(f"Unknown models: {mismatch}")
Expand All @@ -74,14 +85,12 @@ def load_model_preds(
for model_name in (bar := tqdm(models, disable=not pbar)):
bar.set_description(model_name)
pattern = DATA_PATHS[model_name]
df = glob_to_df(pattern, pbar=False).set_index(id_col)
df = glob_to_df(pattern, pbar=False, **kwargs).set_index(id_col)
dfs[model_name] = df

return dfs
if return_model_dfs:
return dfs


def load_df_wbm_with_preds(**kwargs: Any) -> pd.DataFrame:
dfs = load_model_preds(**kwargs)
df_out = df_wbm.copy()
for model_name, df in dfs.items():
model_key = model_name.lower().replace(" ", "_")
Expand All @@ -104,4 +113,5 @@ def load_df_wbm_with_preds(**kwargs: Any) -> pd.DataFrame:
raise ValueError(
f"No pred col for {model_name=}, available cols={list(df)}"
)

return df_out
Loading

0 comments on commit 97f2cc6

Please sign in to comment.