-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_script.py
65 lines (52 loc) · 2.32 KB
/
test_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/python
#
# Notes for interactive testing in an IPython session.
#
# For the main scripts, see `main.py` (model training)
# and `anim.py` (create animations from completed training).
import gc
import importlib
import sys
import matplotlib.pyplot as plt
import tensorflow as tf
from randomthought import plotter
from randomthought import util
from . import main
from . import config
# Fetch MNIST (train and test splits), then run the image arrays of both
# splits through the project's preprocessing helper. Labels are kept as-is.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images, test_images = (util.preprocess_images(train_images),
                             util.preprocess_images(test_images))
# Be aware that as of this writing (January 2023, tf 2.12-nightly), TensorFlow
# can be finicky about resetting if you wish to plot different model
# instances during the same IPython session (which is useful for exploring
# different snapshots of the model).
#
# So let's reset everything: close figure 1 (the figure this script saves at
# the end), reload the `main` module, clear the Keras session, and finally
# force a garbage collection pass.
# NOTE(review): the order presumably matters (the module reload drops old
# references before the session is cleared and gc runs) — keep it as-is.
plt.close(1)  # close the previously drawn figure 1, if any
importlib.reload(main) # re-instantiate CVAE
tf.keras.backend.clear_session() # clean up dangling tensors
gc.collect() # and make sure they are gone
# Optionally, take a snapshot ID from the command line (e.g. when running
# `%run test_script.py 0153` in IPython):
#   "0153"  -> the snapshot saved at epoch 153,
#   "final" -> the final after-training state (the default when no argument
#              is given).
# The ID is used verbatim as the model file stem and in the saved figure name.
snapshot = sys.argv[1] if len(sys.argv) > 1 else "final"
# Load the selected model snapshot into `main.model`.
# NOTE(review): paths are built by plain string concatenation, so
# `config.vae_output_dir` is assumed to end with a path separator — confirm.
main.model = tf.keras.models.load_model(f"{config.vae_output_dir}model/{snapshot}.keras")  # or whatever
# main.model.my_load(f"{config.vae_output_dir}model/final")  # to load a snapshot produced by the legacy custom saver

# plt.ion()  # interactive mode doesn't seem to work well with our heavily customized overlay plot

if main.model.latent_dim == 2:
    # 2-D latent space: draw the latent image, then overlay the training
    # datapoints (colored by label) on top of it.
    # NOTE(review): 21 is presumably the grid resolution of the latent
    # image — confirm against `plotter.plot_latent_image`.
    latent_image = plotter.plot_latent_image(21, model=main.model)
    plotter.overlay_datapoints(train_images, train_labels, latent_image)
else:
    # Higher-dimensional latent space: visualize the test set with every
    # manifold method the plotter offers.
    # n = 4000  # in practice fine
    n = test_images.shape[0]  # using all of the data is VERY slow (~10 minutes), but gives the best view.
    plotter.plot_manifold(test_images[:n, :, :, :],
                          test_labels[:n],
                          model=main.model,
                          methods="all")

# Save figure 1 (tagged with the snapshot ID) and show it. This runs for
# both branches above.
# NOTE(review): assumes the plotter functions draw into figure 1 — if they
# don't, `plt.figure(1)` creates and saves an empty figure; confirm.
fig = plt.figure(1)
fig.savefig(f"{config.vae_output_dir}{config.overlay_fig_basename}_{snapshot}_from_test_script.{config.fig_format}")
fig.canvas.draw_idle()  # see source of `plt.savefig`; need this if 'transparent=True' to reset colors
plt.show()