Skip to content

Commit

Permalink
clarify naming scheme for gridsearch logs
Browse files Browse the repository at this point in the history
  • Loading branch information
JulianKlug committed Aug 12, 2024
1 parent acd3f48 commit 217649a
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "initial_id",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-12T11:45:30.320074Z",
"start_time": "2024-08-12T11:45:30.316606Z"
}
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import os\n",
"import seaborn as sns"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "801ca5fa6e5486ff",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-12T11:45:46.031267Z",
"start_time": "2024-08-12T11:45:46.027641Z"
}
},
"outputs": [],
"source": [
"log_folder_path = '/Users/jk1/temp/opsum_end/training/hyperopt/gridsearch'\n",
"output_dir = '/Users/jk1/Downloads'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38525202312f478",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-12T11:47:45.683875Z",
"start_time": "2024-08-12T11:47:45.628982Z"
}
},
"outputs": [],
"source": [
"# find all jsonl files in log_folder_path\n",
"gs_df = pd.DataFrame()\n",
"for root, dirs, files in os.walk(log_folder_path):\n",
" for file in files:\n",
" if file.endswith('.jsonl'):\n",
" temp_df = pd.read_json(os.path.join(root, file), \n",
" lines=True, dtype={'timestamp': 'object'}, convert_dates=False).drop(0)\n",
" gs_df = pd.concat([gs_df, temp_df], ignore_index=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4168d945d95ab438",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-12T11:47:47.651657Z",
"start_time": "2024-08-12T11:47:47.626756Z"
}
},
"outputs": [],
"source": [
"gs_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76e926edc7bbaba",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-12T11:48:19.824527Z",
"start_time": "2024-08-12T11:48:19.807887Z"
}
},
"outputs": [],
"source": [
"# find best by median_val_scores\n",
"best_df = gs_df.sort_values('median_val_scores', ascending=False).head(1)\n",
"best_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5575b1ac9754ede",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-12T11:48:42.067654Z",
"start_time": "2024-08-12T11:48:41.757260Z"
}
},
"outputs": [],
"source": [
"# plot histogram of median_val_scores\n",
"ax = sns.histplot(x='median_val_scores', data=gs_df)\n",
"ax.figure.set_size_inches(10,10)\n",
"ax.set_title('Median validation scores')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4ad2eb1e6da667eb",
"metadata": {
"ExecuteTime": {
"end_time": "2024-08-12T11:52:12.124761Z",
"start_time": "2024-08-12T11:52:09.989603Z"
}
},
"outputs": [],
"source": [
"# plot a grid with all previous plots\n",
"fig, axes = plt.subplots(4, 3, figsize=(25, 25))\n",
"sns.boxplot(x='num_layers', y='median_val_scores', data=gs_df, ax=axes[0,0])\n",
"sns.boxplot(x='batch_size', y='median_val_scores', data=gs_df, ax=axes[1,0])\n",
"sns.boxplot(x='num_head', y='median_val_scores', data=gs_df, ax=axes[1,2])\n",
"sns.regplot(x='dropout', y='median_val_scores', data=gs_df, ax=axes[2,0])\n",
"sns.regplot(x='train_noise', y='median_val_scores', data=gs_df, logx=True, ax=axes[2,1])\n",
"# set x scale to log for train noise plot\n",
"axes[2,1].set_xscale('log')\n",
"sns.scatterplot(x='lr', y='median_val_scores', data=gs_df, ax=axes[2,2])\n",
"sns.scatterplot(x='weight_decay', y='median_val_scores', data=gs_df, ax=axes[0,2])\n",
"# set x limits to 0, 0.1 for weight decay plot\n",
"# axes[0,2].set_xlim(0, 0.0002)\n",
"sns.scatterplot(x='grad_clip_value', y='median_val_scores', data=gs_df, ax=axes[3,0])\n",
"# set x limits to 0, 0.5 for grad_clip_value plot\n",
"# axes[3,0].set_xlim(0, 0.5)\n",
"\n",
"# # set y limits to 0.88, 0.92 for all plots\n",
"# for ax in axes.flat:\n",
"# ax.set_ylim(0.88, 0.915)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8198a950e7b6945",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def launch_cluster_gridsearch(data_splits_path: str, output_folder: str,
with open(path.join(output_folder, 'gridsearch_config.json'), 'w') as f:
json.dump(gridsearch_config, f)


# REDIS Setup for SLURM/optuna ref: https://github.com/liukidar/stune
if storage_pwd is not None and storage_port is not None:
storage = optuna.storages.JournalStorage(optuna.storages.JournalRedisStorage(
url=f'redis://default:{storage_pwd}@{storage_host}:{storage_port}/opsum'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def get_score(trial, ds, data_splits_path, output_folder, gridsearch_config:dict
d['split_file'] = data_splits_path
text = json.dumps(d)
text += '\n'
dest = path.join(output_folder, 'gridsearch.jsonl')
dest = path.join(output_folder, f'{os.path.basename(output_folder)}_gridsearch.jsonl')
with open(dest, 'a') as handle:
handle.write(text)
print("WRITTEN in ", dest)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ terminado==0.9.4
tf-estimator-nightly==2.8.0.dev2021122109
torch==1.11.0
xgboost==1.6.1
pytorch_lightning

modun @ git+git://github.com/Jimmy2027/MODUN@b106ed60d97bcab0dc629030447c73011d06a3f9

0 comments on commit 217649a

Please sign in to comment.