-
Notifications
You must be signed in to change notification settings - Fork 1
/
viz.py
100 lines (84 loc) · 2.94 KB
/
viz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
# Module-level constant.
# NOTE(review): not referenced by any function in this part of the file —
# confirm whether code elsewhere still imports/uses it before removing.
max_rows = 10
def plot_algorithm_result(res, filename=None):
    """Plot every entry of ``res`` in its own panel, laid out in one row.

    Args:
        res: Mapping from a label to a sequence of values. Each item gets one
            subplot, titled with its key and drawn with the ``".-"`` style.
        filename: If truthy, the first panel's title is replaced by the name
            of the file's parent directory, the figure is saved to
            ``filename`` with ``bbox_inches="tight"``, and then closed.
            If falsy, the figure is left open (e.g. for ``plt.show()``).

    Raises:
        ValueError: If ``res`` is empty, since there is nothing to plot.
    """
    n_k = len(res)
    if n_k == 0:
        raise ValueError("plot_algorithm_result needs at least one entry in res")
    # squeeze=False always yields a 2-D array of Axes. Without it, a single
    # key makes plt.subplots return a bare Axes, and both zip() and axs[0]
    # below would fail.
    fig, axs = plt.subplots(1, n_k, figsize=(21, 9), squeeze=False)
    axs = axs[0]
    for ax, (k, v) in zip(axs, res.items()):
        ax.set_title(k)
        ax.plot(v, ".-")
    if filename:
        # parent.name is "" for a bare filename such as "plot.png", whereas
        # Path(filename).parts[-2] raised IndexError there.
        axs[0].set_title(Path(filename).parent.name)
        fig.savefig(filename, bbox_inches="tight")
        plt.close(fig)
def plot_mean_std_1d(mean, std, filename=None):
    """Plot a mean trace with a one-standard-deviation envelope.

    Args:
        mean: Array-like of mean values (presumably 1-D, per the function
            name), drawn as a solid black line.
        std: Array-like combined element-wise with ``mean``. ``mean + std``
            and ``mean - std`` are drawn as dashed black lines.
        filename: If truthy, the top panel is titled with the name of the
            file's parent directory, and the figure is saved there and closed.
            Defaults to ``None`` (figure left open), matching the other
            plotting helpers. Passing it positionally still works.
    """
    # Convert up front so plain Python lists are added element-wise instead
    # of being concatenated (or rejected, for "-") by the operators below.
    mean = np.asarray(mean)
    std = np.asarray(std)
    fig, ax = plt.subplots(2)
    ax[0].plot(mean, "k")
    ax[0].plot(mean + std, "k--")
    ax[0].plot(mean - std, "k--")
    # NOTE(review): ax[1] is created but never drawn on. It was reserved for
    # an entropy trace that is currently disabled. It is kept so the figure
    # layout is unchanged.
    if filename:
        # parent.name avoids the IndexError parts[-2] raised for bare filenames.
        ax[0].set_title(Path(filename).parent.name)
        fig.savefig(filename, bbox_inches="tight")
        plt.close(fig)
def plot_samples(samples, filename=None):
    """Overlay ``samples`` as faint (alpha 0.1) black lines on a single axes.

    Args:
        samples: Array-like passed directly to ``Axes.plot``.
        filename: If truthy, the axes is titled with the name of the file's
            parent directory, and the figure is saved there and closed.
            If falsy, the figure is left open.
    """
    fig, ax = plt.subplots()
    ax.plot(samples, "k", alpha=0.1)
    if filename:
        # parent.name is "" for a bare filename, whereas
        # Path(filename).parts[-2] raised IndexError.
        ax.set_title(Path(filename).parent.name)
        fig.savefig(filename, bbox_inches="tight")
        plt.close(fig)
def plot_policy_samples(policy, n_samples, d_viz=6, filename=None):
    """Draw ``n_samples`` action sequences from ``policy``; plot each action dim.

    Args:
        policy: Callable. ``policy(n_samples)`` must return a pair whose first
            element is a 3-D array. Presumably its axes are (samples, time,
            action dims) — TODO confirm against the policy implementation.
        n_samples: Forwarded unchanged to ``policy``.
        d_viz: Maximum number of action dimensions to plot (one subplot each).
        filename: If truthy, the first subplot is titled with the name of the
            file's parent directory, and the figure is saved there and closed.
            If falsy, the figure is left open.
    """
    actions, _ = policy(n_samples)
    # Only the last axis size is used. Unpacking three values still enforces
    # that ``actions`` is 3-D.
    _, _, d_a = actions.shape
    d = min(d_a, d_viz)
    # squeeze=False gives a (d, 1) Axes array even when d == 1, replacing the
    # manual single-Axes special case.
    fig, axs = plt.subplots(d, figsize=(12, 9), squeeze=False)
    axs = axs[:, 0]
    for i, ax in enumerate(axs):
        # Transpose so that each index along the first axis becomes one line.
        ax.plot(actions[:, :, i].T, ".-", alpha=0.3)
    if filename:
        # parent.name avoids the IndexError parts[-2] raised for bare filenames.
        axs[0].set_title(Path(filename).parent.name)
        fig.savefig(filename, bbox_inches="tight")
        plt.close(fig)
def plot_sequence(sequence, d_viz=10, filename=None):
    """Plot the columns of a 2-D ``sequence``.

    Args:
        sequence: 2-D array. Each column is treated as one trace.
        d_viz: If ``None``, every column is drawn on a single axes.
            Otherwise at most ``d_viz`` columns are drawn, one subplot each,
            with the ``".-"`` style.
        filename: If truthy, the first axes is titled with the name of the
            file's parent directory, and the figure is saved there and closed.
            If falsy, the figure is left open.
    """
    # Unpacking enforces a 2-D input. Only the column count is needed.
    _, d_s = sequence.shape
    if d_viz is None:
        fig, ax = plt.subplots(figsize=(12, 9))
        ax.plot(sequence)
        axs = [ax]
    else:
        d = min(d_s, d_viz)
        # squeeze=False keeps an indexable (d, 1) array even when d == 1.
        fig, axs = plt.subplots(d, figsize=(12, 9), squeeze=False)
        axs = axs[:, 0]
        for i, ax in enumerate(axs):
            ax.plot(sequence[:, i], ".-")
    if filename:
        # parent.name avoids the IndexError parts[-2] raised for bare filenames.
        axs[0].set_title(Path(filename).parent.name)
        fig.savefig(filename, bbox_inches="tight")
        plt.close(fig)
def plot_sequence_history(sequence, sequence_history, d_viz=20, filename=None):
    """Overlay per-step curve histories faintly, with ``sequence`` on top.

    Args:
        sequence: 1-D array of length ``d_t``, drawn in black with ``".-"``.
        sequence_history: 3-D array of shape ``(d_t, d_s, d_p)``. For every
            step ``t`` and each of the first ``min(d_s, d_viz)`` entries,
            ``sequence_history[t, i, :]`` is drawn over x = t, ..., t + d_p - 1
            at alpha 0.1.
        d_viz: Maximum number of curves per step to draw.
        filename: If truthy, the axes is titled with the name of the file's
            parent directory, and the figure is saved there and closed.
            If falsy, the figure is left open.

    Raises:
        AssertionError: If the lengths of ``sequence`` and
            ``sequence_history`` differ along the first axis.
    """
    (d_t_,) = sequence.shape
    d_t, d_s, d_p = sequence_history.shape
    assert d_t == d_t_, (
        f"sequence has {d_t_} steps but sequence_history has {d_t}"
    )
    d_s = min(d_s, d_viz)
    fig, ax = plt.subplots(figsize=(12, 9))
    for t in range(d_t):
        # Every curve stored at step t starts at x = t. Build its x-range once.
        x = np.arange(t, t + d_p)
        for i in range(d_s):
            ax.plot(x, sequence_history[t, i, :], alpha=0.1)
    ax.plot(sequence, "k.-")
    if filename:
        # parent.name avoids the IndexError parts[-2] raised for bare filenames.
        ax.set_title(Path(filename).parent.name)
        fig.savefig(filename, bbox_inches="tight")
        plt.close(fig)
def plot_smoothness(spectrum, frequency, signal, filename=None):
    """Show ``signal`` next to its spectrum in two side-by-side panels.

    Args:
        spectrum: Values plotted against ``frequency`` in the right panel.
        frequency: x-coordinates for ``spectrum``.
        signal: Values plotted against their index in the left panel,
            labelled "Timesteps" / "Action Norm".
        filename: If truthy, the figure is saved there and closed.
            Otherwise the figure is left open.
    """
    fig, (time_ax, freq_ax) = plt.subplots(1, 2, figsize=(12, 9))

    # Left panel: the signal against its index.
    time_ax.plot(signal)
    time_ax.set_xlabel("Timesteps")
    time_ax.set_ylabel("Action Norm")

    # Right panel: the spectrum against frequency.
    freq_ax.plot(frequency, spectrum)
    freq_ax.set_xlabel("Frequency")
    freq_ax.set_ylabel("Spectrum")

    if not filename:
        return
    fig.savefig(filename, bbox_inches="tight")
    plt.close(fig)