-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplot_train_progress.py
159 lines (123 loc) · 3.83 KB
/
plot_train_progress.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
import sys
import csv
import glob
import warnings
import numpy as np
from matplotlib import pyplot as plt
def ma(data, window, *, mode='valid'):
    """Moving average filter

    Smooths `data` by convolving it with a uniform kernel of `window` taps,
    each weighted ``1 / window``.

    Args:
        data (numpy array): The data to be smoothed
        window (int): The window used to smooth the data, must be >= 1
        mode (str): The mode used to smooth the data, passed straight to
            ``np.convolve`` ('valid', 'same' or 'full')

    Returns:
        data_smoothed (numpy array): The data smoothed by the moving average

    Raises:
        ValueError: If `window` is smaller than 1
    """
    # Fail early with a readable message instead of letting numpy complain
    # about negative dimensions or an empty kernel
    if window < 1:
        raise ValueError(
            "window must be a positive integer, got {}".format(window)
        )

    # Uniform kernel: every sample inside the window has equal weight
    mask = np.ones(window) / window
    return np.convolve(data, mask, mode)
def get_reward_from_csv(file):
    """Extract reward vector from a csv file

    The first two rows of the file and any empty rows are skipped. The first
    cell of every remaining row is parsed as a float reward.

    Args:
        file (str): Path of the csv file to read

    Returns:
        reward (numpy array): 1-D float array with one reward per data row
    """
    reward = []
    print("Reading reward from {}".format(file))

    # newline='' lets the csv module do its own line-ending handling, as
    # recommended by the csv documentation
    with open(file, 'rt', newline='') as f:
        for ri, row in enumerate(csv.reader(f)):
            if ri < 2:
                # Skip first two rows
                continue

            if not row:
                # Skip empty rows
                continue

            # XXX ajs 11/Sep/2019 There seems to be bug with
            # csv.DictWriter.writerow() as used in
            # stable_baselines/bench/monitor.py:98 - we occasionally see
            # reward values that are missing the mantissa, e.g.
            # 'XXX.XXe-5' will simply appear as 'e-5' in the CSV cell.
            # To work-around this, we detect these cases and replace with 0
            if row[0].split("e")[0] == '':
                warnings.warn(
                    "Got malformed reward (no mantissa) "
                    "at line {} of {}: {}, replacing with 0.0".format(
                        ri,
                        file,
                        row[0]
                    )
                )
                reward.append(0.0)
                continue

            reward.append(float(row[0]))

    return np.array(reward, dtype=float)
def plot_csv_glob(fileglob, window, *, quartiles=(10, 50, 90), **kwargs):
    """Plot performance over many seeds into the current axes

    Loads the reward vector of every file matched by `fileglob`, crops them
    all to the shortest one, computes three percentile curves across files
    and smooths each with a moving average. The middle curve is drawn as a
    line and the span between the outer curves as a translucent band.

    Args:
        fileglob (str): File path glob matching one or more monitor.csv files
        window (int): Moving average window
        quartiles (sequence): Three percentiles (lower, middle, upper) in the
            range 0-100, evaluated across files
        kwargs (dict): Plotting keyword args, forwarded to ``plt.plot``

    Returns:
        rewards (numpy array): The cropped, unsmoothed rewards with one row
            per matched file

    Note:
        Exits the process via ``sys.exit(-1)`` if the glob matches no files.
    """
    files = glob.glob(fileglob)
    if not files:
        warnings.warn("Glob '{}' matched 0 files".format(fileglob))
        sys.exit(-1)

    print("Loading rewards from {} files".format(len(files)))
    rewards = [get_reward_from_csv(f) for f in files]

    reward_len = [len(r) for r in rewards]
    min_len = min(reward_len)
    max_len = max(reward_len)
    print("Min training length: {}, max: {}".format(
        min_len,
        max_len
    ))

    # Crop all rewards to the minimum length so they stack into a 2-D array
    rewards = np.array([
        r[0:min_len]
        for r in rewards
    ], dtype=float)

    # Percentiles are taken across files (axis 0), then each curve is smoothed
    q1, q2, q3 = (
        ma(q, window)
        for q in np.percentile(rewards, q=quartiles, axis=0)
    )

    # The default 'valid' smoothing shortens the curves, so shift the x axis
    # by half a window to keep each point roughly centred on its window
    x = np.arange(window // 2, window // 2 + len(q1))

    p0 = plt.plot(x, q2, **kwargs)
    # Shade the band in the same colour as the line that was just drawn
    plt.fill_between(
        x,
        q1,
        q3,
        color=p0[0].get_color(),
        alpha=0.1,
        lw=0
    )

    return rewards
def main(fileglob, window, **kwargs):
    """Plot model training progress

    Opens a new figure, draws the reward curves for every file matched by
    `fileglob`, decorates the axes and blocks until the window is closed.

    Args:
        fileglob (str): File path glob matching one or more monitor.csv files
        window (int): Moving average window
        kwargs (dict): Extra keyword args forwarded to ``plot_csv_glob``
    """
    # Ensure we are using a GUI frontend so X-Forwarding works
    import matplotlib
    matplotlib.use('tkagg')

    plt.figure()
    rewards = plot_csv_glob(fileglob, window, **kwargs)

    # One row of rewards per matched file
    num_files = len(rewards)
    title = "{} ({} files)".format(fileglob, num_files)
    plt.title(title)

    # Axis labels
    plt.xlabel("1e3 timesteps")
    plt.ylabel("Reward")

    # Fixed axis ranges
    y_range = (-50, 1050)
    plt.ylim(y_range[0], y_range[1])
    plt.xlim(0, 6e6 / 1e3)

    plt.grid()
    plt.tight_layout()

    plt.show()
    plt.close()
if __name__ == '__main__':
    # Two positional arguments are required: the file glob and the moving
    # average window. Exit with a usage message rather than an IndexError
    # traceback when they are missing.
    if len(sys.argv) < 3:
        sys.exit("Usage: {} FILEGLOB WINDOW".format(sys.argv[0]))

    fileglob = sys.argv[1]
    window = int(sys.argv[2])
    main(fileglob, window)