-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
clarify naming scheme for gridsearch logs
- Loading branch information
1 parent
acd3f48
commit 217649a
Showing
4 changed files
with
180 additions
and
1 deletion.
There are no files selected for viewing
176 changes: 176 additions & 0 deletions
176
prediction/short_term_outcome_prediction/cluster/gridsearch_evaluation.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "initial_id", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-12T11:45:30.320074Z", | ||
"start_time": "2024-08-12T11:45:30.316606Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as plt\n", | ||
"import pandas as pd\n", | ||
"import os\n", | ||
"import seaborn as sns" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "801ca5fa6e5486ff", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-12T11:45:46.031267Z", | ||
"start_time": "2024-08-12T11:45:46.027641Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Root folder that is walked recursively for gridsearch *.jsonl log files.\n", | ||
"# NOTE(review): machine-specific absolute path -- adjust before running elsewhere.\n", | ||
"log_folder_path = '/Users/jk1/temp/opsum_end/training/hyperopt/gridsearch'\n", | ||
"# Not referenced by any other visible cell; presumably reserved for saving outputs -- TODO confirm.\n", | ||
"output_dir = '/Users/jk1/Downloads'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "38525202312f478", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-12T11:47:45.683875Z", | ||
"start_time": "2024-08-12T11:47:45.628982Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# find all jsonl files in log_folder_path\n", | ||
"gs_df = pd.DataFrame()\n", | ||
"for root, dirs, files in os.walk(log_folder_path):\n", | ||
" for file in files:\n", | ||
" if file.endswith('.jsonl'):\n", | ||
" temp_df = pd.read_json(os.path.join(root, file), \n", | ||
" lines=True, dtype={'timestamp': 'object'}, convert_dates=False).drop(0)\n", | ||
" gs_df = pd.concat([gs_df, temp_df], ignore_index=True)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "4168d945d95ab438", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-12T11:47:47.651657Z", | ||
"start_time": "2024-08-12T11:47:47.626756Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# display the combined table of all loaded gridsearch log rows\n", | ||
"gs_df" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "76e926edc7bbaba", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-12T11:48:19.824527Z", | ||
"start_time": "2024-08-12T11:48:19.807887Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# find best by median_val_scores\n", | ||
"best_df = gs_df.sort_values('median_val_scores', ascending=False).head(1)\n", | ||
"best_df" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f5575b1ac9754ede", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-12T11:48:42.067654Z", | ||
"start_time": "2024-08-12T11:48:41.757260Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# plot histogram of median_val_scores\n", | ||
"ax = sns.histplot(x='median_val_scores', data=gs_df)\n", | ||
"ax.figure.set_size_inches(10,10)\n", | ||
"ax.set_title('Median validation scores')\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "4ad2eb1e6da667eb", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-12T11:52:12.124761Z", | ||
"start_time": "2024-08-12T11:52:09.989603Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# plot a grid with all previous plots\n", | ||
"fig, axes = plt.subplots(4, 3, figsize=(25, 25))\n", | ||
"sns.boxplot(x='num_layers', y='median_val_scores', data=gs_df, ax=axes[0,0])\n", | ||
"sns.boxplot(x='batch_size', y='median_val_scores', data=gs_df, ax=axes[1,0])\n", | ||
"sns.boxplot(x='num_head', y='median_val_scores', data=gs_df, ax=axes[1,2])\n", | ||
"sns.regplot(x='dropout', y='median_val_scores', data=gs_df, ax=axes[2,0])\n", | ||
"sns.regplot(x='train_noise', y='median_val_scores', data=gs_df, logx=True, ax=axes[2,1])\n", | ||
"# set x scale to log for train noise plot\n", | ||
"axes[2,1].set_xscale('log')\n", | ||
"sns.scatterplot(x='lr', y='median_val_scores', data=gs_df, ax=axes[2,2])\n", | ||
"sns.scatterplot(x='weight_decay', y='median_val_scores', data=gs_df, ax=axes[0,2])\n", | ||
"# set x limits to 0, 0.1 for weight decay plot\n", | ||
"# axes[0,2].set_xlim(0, 0.0002)\n", | ||
"sns.scatterplot(x='grad_clip_value', y='median_val_scores', data=gs_df, ax=axes[3,0])\n", | ||
"# set x limits to 0, 0.5 for grad_clip_value plot\n", | ||
"# axes[3,0].set_xlim(0, 0.5)\n", | ||
"\n", | ||
"# # set y limits to 0.88, 0.92 for all plots\n", | ||
"# for ax in axes.flat:\n", | ||
"# ax.set_ylim(0.88, 0.915)\n", | ||
"\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a8198a950e7b6945", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters