Skip to content

Commit

Permalink
Improved formatting of minibatch_cvxlin_timing.py plots.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexshtf committed Aug 28, 2023
1 parent 5b0a41c commit 7d6ef86
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions experiments/minibatch_cvxlin_timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,23 @@ def run_experiment(desc: ExperimentDesc):
df = pd.DataFrame.from_records(results)
df.to_csv('minibatch_cvxlin_timing.csv')

plt.figure(figsize=(16, 12))
df = pd.read_csv('minibatch_cvxlin_timing.csv')
df = pd.read_csv('minibatch_cvxlin_timing.csv') \
.rename(columns={'prox_pt_time': 'Prox-PT', 'sgd_time': 'SGD',
'batch_size': 'Batch Size'})

sns.set(context='paper', palette='Set1', style='ticks', font_scale=1.5)
sns.lmplot(data=df, x='sgd_time', y='prox_pt_time', hue='type',
col='dim', row='batch_size', palette="Set1", ci=None, facet_kws=dict(sharex=False, sharey=False))
g = sns.FacetGrid(data=df[df['type'] == 'Least squares'], col='dim', row='Batch Size', palette="Set1",
sharex=False, sharey=False, margin_titles=True, despine=False,
height=3, aspect=1.2)
plt.suptitle('Least squares proximal point vs. SGD running time (seconds)')
g.map_dataframe(sns.regplot, x='SGD', y='Prox-PT', ci=None)
g.add_legend()
plt.show()

g = sns.FacetGrid(data=df[df['type'] == 'Logistic regression'], col='dim', row='Batch Size', palette="Set1",
sharex=False, sharey=False, margin_titles=True, despine=False,
height=3, aspect=1.2)
plt.suptitle('Logistic regression proximal point vs. SGD running time (seconds)')
g.map_dataframe(sns.regplot, x='SGD', y='Prox-PT', ci=None)
g.add_legend()
plt.show()

0 comments on commit 7d6ef86

Please sign in to comment.