-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_loss_figure.py
40 lines (33 loc) · 1.22 KB
/
gen_loss_figure.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import argparse
import matplotlib.pyplot as plt
import numpy as np
def plot_loss(task, experiment):
    """Plot the per-step metric parsed from a training log and save it as a PNG.

    Reads ``models/<experiment>/<task>.log``. For every line whose
    whitespace-split tokens include the exact token ``"Step"``, the step number
    is taken from token 4 (the part before ``/``) and the metric from token 6
    (with its last character stripped). Only the first value seen for each
    step is kept. The resulting curve is written to
    ``output/<experiment>/<experiment>-loss.png`` (the directory is created if
    needed).

    NOTE(review): despite the function/file name ("loss"), the y-axis is
    labelled "Accuracy" — confirm which metric token 6 actually holds.

    Args:
        task: Task name; selects the log file inside the experiment directory.
        experiment: Experiment name; selects the log and output directories
            and is used as the plot title.

    Raises:
        FileNotFoundError: If the log file does not exist.
        ValueError: If a matched line's step or metric token is not numeric.
    """
    log_file_path = "models/{}/{}.log".format(experiment, task)
    steps_list, acc_list = [], []
    seen_steps = set()  # O(1) duplicate check instead of scanning steps_list
    with open(log_file_path, "r", encoding="utf-8", errors="replace") as file:
        for line in file:
            line_split = line.split()
            # Skip lines without the "Step" token, and ones too short to hold
            # both the step (index 4) and metric (index 6) tokens.
            if "Step" not in line_split or len(line_split) < 7:
                continue
            step_num = int(line_split[4].split("/")[0])
            if step_num in seen_steps:
                continue
            seen_steps.add(step_num)
            steps_list.append(step_num)
            # Drop the trailing character of the token (presumably a separator
            # such as ',' or '%' — TODO confirm against the log format).
            acc_list.append(float(line_split[6][:-1]))

    savepath = "output/{}/{}-loss.png".format(experiment, experiment)
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(savepath), exist_ok=True)

    # Use a dedicated figure so repeated calls do not draw onto each other's
    # axes, and always close it so figures are not leaked.
    fig, ax = plt.subplots()
    try:
        ax.plot(steps_list, acc_list)
        ax.set_title(experiment)
        ax.set_xlabel("Train steps")
        ax.set_ylabel("Accuracy")
        # arange excludes its stop value; use 101 so the 100 tick is shown.
        ax.set_yticks(np.arange(0, 101, 10))
        fig.savefig(savepath)
    finally:
        plt.close(fig)
def main(config):
    """Lower-case the parsed CLI options and render the figure for that run.

    Args:
        config: Parsed arguments exposing string attributes ``experiment``
            and ``task``.
    """
    exp_name = config.experiment.lower()
    task_name = config.task.lower()
    plot_loss(task=task_name, experiment=exp_name)
if __name__ == "__main__":
    # Command-line entry point; both values are lower-cased by main().
    parser = argparse.ArgumentParser(
        description="Plot the training curve parsed from "
                    "models/<experiment>/<task>.log.",
    )
    parser.add_argument(
        '--task', type=str, default="twitter_conv",
        help="task name; selects models/<experiment>/<task>.log "
             "(case-insensitive, default: %(default)s)",
    )
    parser.add_argument(
        '--experiment', type=str, default="twitter_conv_0",
        help="experiment name; selects the models/ and output/ subdirectories "
             "(case-insensitive, default: %(default)s)",
    )
    config = parser.parse_args()
    main(config)