diff --git a/DenseNet/plot_results.py b/DenseNet/plot_results.py index 71288c3..8d33680 100644 --- a/DenseNet/plot_results.py +++ b/DenseNet/plot_results.py @@ -3,18 +3,35 @@ import numpy as np -def plot_cifar10(): +def plot_cifar10(save=True): - with open("experiment_log_cifar10.json", "r") as f: + with open("./log/experiment_log_cifar10.json", "r") as f: d = json.load(f) - train_accuracy = 100 * (1 - np.array(d["train_loss"])[:, 1]) - test_accuracy = 100 * (1 - np.array(d["test_loss"])[:, 1]) + train_accuracy = 100 * (np.array(d["train_loss"])[:, 1]) + test_accuracy = 100 * (np.array(d["test_loss"])[:, 1]) - plt.plot(train_accuracy, color="tomato", linewidth=2) - plt.plot(test_accuracy, color="steelblue", linewidth=2) + fig = plt.figure() + ax1 = fig.add_subplot(111) + ax1.set_ylabel('Accuracy') + ax1.plot(train_accuracy, color="tomato", linewidth=2, label='train_acc') + ax1.plot(test_accuracy, color="steelblue", linewidth=2, label='test_acc') + ax1.legend(loc=0) + + train_loss = np.array(d["train_loss"])[:, 0] + test_loss = np.array(d["test_loss"])[:, 0] + + ax2 = ax1.twinx() + ax2.set_ylabel('Loss') + ax2.plot(train_loss, '--', color="tomato", linewidth=2, label='train_loss') + ax2.plot(test_loss, '--', color="steelblue", linewidth=2, label='test_loss') + ax2.legend(loc=1) + + ax1.grid(True) + + if save: + fig.savefig('./figures/plot_cifar10.svg') - plt.grid() plt.show() plt.clf() plt.close()