-
Notifications
You must be signed in to change notification settings - Fork 166
/
generator.py
31 lines (24 loc) · 942 Bytes
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
from scipy.misc import imsave
class Generator(object):
    '''Base class for image generators.

    Subclasses must override ``update_params``. ``generate_and_save_images``
    also relies on two attributes that this class never sets, so a subclass
    must provide them:

    * ``self.sess``: an object with a ``run`` method (presumably a TensorFlow
      session -- TODO confirm against subclasses).
    * ``self.sampled_tensor``: the value passed to ``self.sess.run``; the
      result must be indexable along its first axis, with each entry holding
      28*28 values.
    '''

    def update_params(self, input_tensor):
        '''Update parameters of the network.

        Args:
            input_tensor: a batch of flattened images

        Returns:
            Current loss value

        Raises:
            NotImplementedError: always; subclasses must override this method.
        '''
        raise NotImplementedError()

    def generate_and_save_images(self, num_samples, directory, save_fn=None):
        '''Generate images with the model and save them as PNG files.

        Images are written to ``<directory>/imgs/<k>.png`` for k = 0, 1, ...

        Args:
            num_samples: number of samples to generate. NOTE(review): this
                argument is currently unused -- the number of files written
                equals the first dimension of the batch returned by
                ``self.sess.run(self.sampled_tensor)``.
            directory: a directory to save the images; an ``imgs``
                subdirectory is created inside it if it does not exist.
            save_fn: optional callable ``save_fn(path, image)`` used to write
                each 28x28 image. Defaults to the module-level ``imsave``
                (``scipy.misc.imsave``), which was removed in SciPy 1.2, so
                newer installs should pass a replacement writer.
        '''
        if save_fn is None:
            save_fn = imsave

        imgs = self.sess.run(self.sampled_tensor)

        # The output folder does not depend on k: build and create it once.
        imgs_folder = os.path.join(directory, 'imgs')
        if not os.path.exists(imgs_folder):
            os.makedirs(imgs_folder)

        for k in range(imgs.shape[0]):
            # Format only the file name. Applying ``%`` to the joined path
            # would misbehave whenever ``directory`` itself contains a '%'.
            path = os.path.join(imgs_folder, '%d.png' % k)
            # Each sample must contain exactly 28*28 values to be reshaped.
            save_fn(path, imgs[k].reshape(28, 28))