Skip to content

Commit

Permalink
Improved plot visualisation.
Browse files Browse the repository at this point in the history
- Fixed problem showing the accuracy as if it was error rate
- Added loss in additional axis
- Added option to export the plot
  • Loading branch information
perellonieto committed Nov 10, 2017
1 parent b75f09f commit e8de9b1
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions DenseNet/plot_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,35 @@
import numpy as np


def plot_cifar10():
def plot_cifar10(save=True):

with open("experiment_log_cifar10.json", "r") as f:
with open("./log/experiment_log_cifar10.json", "r") as f:
d = json.load(f)

train_accuracy = 100 * (1 - np.array(d["train_loss"])[:, 1])
test_accuracy = 100 * (1 - np.array(d["test_loss"])[:, 1])
train_accuracy = 100 * (np.array(d["train_loss"])[:, 1])
test_accuracy = 100 * (np.array(d["test_loss"])[:, 1])

plt.plot(train_accuracy, color="tomato", linewidth=2)
plt.plot(test_accuracy, color="steelblue", linewidth=2)
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.set_ylabel('Accuracy')
ax1.plot(train_accuracy, color="tomato", linewidth=2, label='train_acc')
ax1.plot(test_accuracy, color="steelblue", linewidth=2, label='test_acc')
ax1.legend(loc=0)

train_loss = np.array(d["train_loss"])[:, 0]
test_loss = np.array(d["test_loss"])[:, 0]

ax2 = ax1.twinx()
ax2.set_ylabel('Loss')
ax2.plot(train_loss, '--', color="tomato", linewidth=2, label='train_loss')
ax2.plot(test_loss, '--', color="steelblue", linewidth=2, label='test_loss')
ax2.legend(loc=1)

ax1.grid(True)

if save:
fig.savefig('./figures/plot_cifar10.svg')

plt.grid()
plt.show()
plt.clf()
plt.close()
Expand Down

0 comments on commit e8de9b1

Please sign in to comment.