forked from hardmaru/sketch-rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sample.py
194 lines (161 loc) · 6.48 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import numpy as np
import tensorflow as tf
import time
import os
import cPickle
import argparse
from utils import *
from model import Model
import random
import svgwrite
from IPython.display import SVG, display
# The driver code below runs at module level (instead of inside a main()
# function) so that this script can also be executed from within IPython.
def in_ipython():
    """Tell whether this code is running inside an IPython session.

    Relies on IPython making the name ``__IPYTHON__`` available; in a plain
    interpreter looking that name up raises NameError.

    Returns:
        True when ``__IPYTHON__`` resolves, False otherwise.
    """
    try:
        __IPYTHON__
        return True
    except NameError:
        return False
parser = argparse.ArgumentParser()
parser.add_argument('--real', type=int, default=0,
help='real, 1 if you want real sketches, or 0 for fake')
parser.add_argument('--filename', type=str, default='output/fake/fake',
help='filename of .svg file to output, without .svg')
parser.add_argument('--sample_length', type=int, default=600,
help='number of strokes to sample')
parser.add_argument('--picture_size', type=float, default=109,
help='a centered svg will be generated of this size')
parser.add_argument('--scale_factor', type=float, default=1,
help='factor to scale down by for svg output. smaller means bigger output')
parser.add_argument('--num_picture', type=int, default=20,
help='number of pictures to generate')
parser.add_argument('--num_col', type=int, default=1,
help='if num_picture > 1, how many pictures per row?')
parser.add_argument('--dataset_name', type=str, default="kanji",
help='name of directory containing training data')
parser.add_argument('--color_mode', type=int, default=0,
help='set to 0 if you are a black and white sort of person...')
parser.add_argument('--stroke_width', type=float, default=2.0,
help='thickness of pen lines')
parser.add_argument('--temperature', type=float, default=0.1,
help='sampling temperature')
sample_args = parser.parse_args()
color_mode = True
if sample_args.color_mode == 0:
color_mode = False
with open(os.path.join('save', sample_args.dataset_name, 'config.pkl')) as f: # future
saved_args = cPickle.load(f)
model = Model(saved_args, True)
sess = tf.InteractiveSession()
saver = tf.train.Saver(tf.all_variables())
ckpt = tf.train.get_checkpoint_state(
os.path.join('save', sample_args.dataset_name))
print "loading model: ", ckpt.model_checkpoint_path
saver.restore(sess, ckpt.model_checkpoint_path)
def draw_strokes_args(data, count=0):
    """Save one sketch as an SVG using the command-line drawing options.

    The output file is ``<--filename>_<count zero-padded to 4 digits>.svg``;
    the line thickness comes from ``--stroke_width`` and the canvas size from
    ``--picture_size``.

    Args:
        data: stroke sequence forwarded unchanged to draw_strokes.
        count: index used to number the output file.
    """
    numbered = str(count).zfill(4)
    target = '%s_%s.svg' % (sample_args.filename, numbered)
    draw_strokes(data,
                 svg_filename=target,
                 stroke_width=sample_args.stroke_width,
                 block_size=sample_args.picture_size)
def sample_sketches(min_size_ratio=0.0, max_size_ratio=1, min_num_stroke=4, max_num_stroke=22, block_size=200, svg_only=True):
N = sample_args.num_picture
frame_size = float(sample_args.picture_size)
max_size = frame_size * max_size_ratio
min_size = frame_size * min_size_ratio
count = 0
sketch_list = []
param_list = []
temp_mixture = sample_args.temperature
temp_pen = sample_args.temperature
while count < N:
# print "attempting to generate picture #", count
print '.',
[strokes, params] = model.sample(
sess, sample_args.sample_length, temp_mixture, temp_pen, stop_if_eoc=True)
[sx, sy, num_stroke, num_char, _] = strokes.sum(0)
if num_stroke < min_num_stroke or num_char == 0 or num_stroke > max_num_stroke:
# print "num_stroke ", num_stroke, " num_char ", num_char
continue
[sx, sy, sizex, sizey] = calculate_start_point(strokes)
if sizex > max_size or sizey > max_size:
# print "sizex ", sizex, " sizey ", sizey
continue
if sizex < min_size or sizey < min_size:
# print "sizex ", sizex, " sizey ", sizey
continue
# success
draw_strokes_args(strokes, count=count)
print count + 1, "/", N
count += 1
# sketch_list = [strokes]
# param_list.append(params)
# draw_sketch_array(sketch_list, count=count, svg_only=svg_only)
# draw the pics
# draw_sketch_array(sketch_list, svg_only=svg_only)
return sketch_list, param_list
def get_bounds(data, block_size=200):
    """Compute the bounding box of a stroke sequence and a scale to fit it.

    ``data`` is indexed by row with columns 0 and 1 holding per-step x/y
    offsets. The pen trajectory is their running sum, starting from (0, 0).

    Args:
        data: 2-D array-like of stroke rows with at least two columns.
        block_size: side length of the square output canvas.

    Returns:
        ``(min_x, max_x, min_y, max_y, padding, factori)``.

        ``factori`` is ``ceil(padding + larger extent) / block_size``. Dividing
        stroke units by it maps the drawing, plus ``padding``, onto the canvas.
        All four bounds are given in canvas units, i.e. already divided by
        ``factori``. Previously only the minima were scaled while the maxima
        stayed in raw stroke units.
    """
    padding = 10
    if len(data) == 0:
        min_x = max_x = min_y = max_y = 0.0
    else:
        rows = np.asarray(data)
        # Absolute pen positions are cumulative sums of the offsets. The
        # origin is prepended because the trajectory starts at (0, 0).
        # float64 accumulation matches summing Python floats step by step.
        xs = np.concatenate(([0.0], np.cumsum(rows[:, 0].astype(np.float64))))
        ys = np.concatenate(([0.0], np.cumsum(rows[:, 1].astype(np.float64))))
        min_x, max_x = xs.min(), xs.max()
        min_y, max_y = ys.min(), ys.max()
    extent = max(max_x - min_x, max_y - min_y)
    factori = np.ceil(padding + extent) / block_size
    return (min_x / factori, max_x / factori,
            min_y / factori, max_y / factori,
            padding, factori)
# Render a stroke sequence to an .svg file (inline display is commented out).
def draw_strokes(data, svg_filename='sample.svg', stroke_width=3, block_size=200):
    """Draw ``data`` as a black path on a white square canvas and save it.

    Args:
        data: 2-D array of stroke rows. Columns 0 and 1 are per-step x/y
            offsets. Column 2 is the pen-lift flag: a value of 1 makes the
            *following* step a move instead of a line.
        svg_filename: path of the .svg file to write.
        stroke_width: thickness of the drawn path.
        block_size: side length of the square canvas. The drawing is scaled
            by ``get_bounds`` to fit it.
    """
    min_x, _, min_y, _, padding, factori = get_bounds(
        data, block_size=block_size)
    dims = (block_size, block_size)
    dwg = svgwrite.Drawing(svg_filename, size=dims)
    dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))
    # Shift so the bounding box's minimum corner sits at (padding/2, padding/2).
    start_x = padding / 2 - min_x
    start_y = padding / 2 - min_y
    path_data = _stroke_path(data, start_x, start_y, factori)
    dwg.add(dwg.path(path_data).stroke("black", stroke_width).fill('none'))
    dwg.save()
    # display(SVG(dwg.tostring()))  # uncomment to preview inline in IPython


def _stroke_path(data, start_x, start_y, factori):
    """Build SVG path data for ``data`` (offsets divided by ``factori``)."""
    parts = ["M%s,%s " % (start_x, start_y)]
    command = "m"
    lift_pen = 1  # the very first step is always a move
    # Keep the exact scaled pen position and emit integer deltas between
    # *rounded* positions. Truncating each offset independently (the previous
    # behaviour) let the rounding error accumulate, so long sketches drifted
    # toward the origin.
    pos_x = pos_y = 0.0
    drawn_x = drawn_y = 0
    for row in data:
        if lift_pen == 1:
            command = "m"        # relative move: pen is up
        elif command != "l":
            command = "l"        # first segment after a move
        else:
            command = ""         # further pairs continue the implicit lineto
        pos_x += float(row[0]) / factori
        pos_y += float(row[1]) / factori
        new_x = int(round(pos_x))
        new_y = int(round(pos_y))
        parts.append("%s%d,%d " % (command, new_x - drawn_x, new_y - drawn_y))
        drawn_x, drawn_y = new_x, new_y
        lift_pen = row[2]
    return "".join(parts)
def do_real(filename="kanji/short_kanji.npz", out_prefix='output/real/real'):
    """Render sketches from the training set to numbered SVG files.

    Draws the first ``--num_picture`` entries of the ``'train'`` array, or
    all of them if there are fewer. They use the same ``--stroke_width`` and
    ``--picture_size`` options as generated sketches; those options default
    to 2.0 and 109, the values this function previously hard-coded.

    Args:
        filename: .npz file containing a ``'train'`` entry. The kanji files
            come from
            https://github.com/hardmaru/sketch-rnn-datasets/tree/master/kanji
        out_prefix: output path prefix; files are named
            ``<out_prefix>_NNNN.svg``.
    """
    # Indexing the NpzFile loads the array, so it stays valid after the
    # context manager closes the underlying file handle.
    with np.load(filename) as load_data:
        train_set = load_data['train']
    # Never index past the end when more pictures are requested than exist.
    total = min(sample_args.num_picture, len(train_set))
    for i in range(total):
        draw_strokes(train_set[i],
                     stroke_width=sample_args.stroke_width,
                     block_size=sample_args.picture_size,
                     svg_filename='%s_%s.svg' % (out_prefix, str(i).zfill(4)))
        print("{}/{}".format(i + 1, total))
if __name__ == '__main__':
ipython_mode = in_ipython()
if ipython_mode:
print "IPython detected"
else:
print "Console mode"
if sample_args.real:
do_real()
else:
[strokes, params] = sample_sketches(svg_only=not ipython_mode)