-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.py
115 lines (103 loc) · 3.48 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/env python
import json
import os
import numpy as np
import tensorflow as tf
import model, sample, encoder
from urllib.parse import urlparse, parse_qsl
from http.server import BaseHTTPRequestHandler, HTTPServer
def get_text(
    text,
    length=None,
    temperature=1,
    top_k=40,
    model_name='117M',
    seed=None,
    nsamples=1,
    batch_size=1
):
    """Generate continuations of *text* with the model stored under models/<model_name>.

    Args:
        text: Prompt string; it is encoded with the model's encoder and fed
            as the sampling context.
        length: Number of tokens to sample. When None, half of the model's
            context window (``hparams.n_ctx // 2``) is used.
        temperature: Passed through to ``sample.sample_sequence``.
        top_k: Passed through to ``sample.sample_sequence``.
        model_name: Sub-directory of ``models`` holding ``hparams.json`` and
            the checkpoint.
        seed: Seed given to both NumPy and TensorFlow random generators.
        nsamples: Total number of samples to produce.
        batch_size: Samples generated per session run; None is treated as 1.

    Returns:
        A list of decoded strings, one per sample, with the prompt tokens
        sliced off the front of each sampled sequence.

    Raises:
        ValueError: If ``nsamples`` is not a multiple of ``batch_size`` or
            ``length`` exceeds the model's context window.
    """
    print('Generating text for input "%s"...' % text)
    if batch_size is None:
        batch_size = 1
    # Raise explicitly instead of asserting: asserts are stripped under -O,
    # and the length check below already reports bad input via ValueError.
    if nsamples % batch_size != 0:
        raise ValueError(
            "nsamples (%s) must be a multiple of batch_size (%s)"
            % (nsamples, batch_size)
        )
    print(model_name)

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    result = []
    # A fresh graph per call keeps repeated calls from piling ops onto one graph.
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
        context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.compat.v1.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k
        )

        saver = tf.compat.v1.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
        saver.restore(sess, ckpt)

        context_tokens = enc.encode(text)
        prompt_len = len(context_tokens)
        # The same prompt is fed for every row of the batch.
        batch = [context_tokens] * batch_size
        for _ in range(nsamples // batch_size):
            # Drop the echoed prompt so only newly sampled tokens are decoded.
            out = sess.run(output, feed_dict={context: batch})[:, prompt_len:]
            result.extend(enc.decode(tokens) for tokens in out)
    return result
def get_clean_output(output):
    """Return the first sample in *output*, cut at the end-of-text marker.

    Args:
        output: Sequence of generated strings.

    Returns:
        The text of ``output[0]`` preceding the first ``<|endoftext|>``
        marker (the whole string if the marker is absent), or None when
        *output* is empty.
    """
    if len(output) == 0:
        return None
    first_sample = output[0]
    # partition() yields everything before the first marker, or the full
    # string when the marker does not occur.
    head, _, _ = first_sample.partition('<|endoftext|>')
    return head
class server(BaseHTTPRequestHandler):
    """HTTP request handler that returns generated text as JSON for GET requests.

    Recognised query parameters: ``q`` (prompt), ``length`` (int),
    ``top_k`` (int) and ``temperature`` (float).
    """

    def _send_json(self, status, payload):
        """Send *payload* serialised as JSON with the given HTTP status code."""
        body = bytes(json.dumps(payload), 'utf-8')
        self.send_response(status)
        self.send_header('Content-type', 'application/json')
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Handle a GET request.

        Responds with ``{}`` when no ``q`` parameter is given. Otherwise
        responds with the prompt, the parameters that were supplied, and
        ``output`` (prompt + generated continuation). Malformed numeric
        parameters, or a ValueError from ``get_text`` (e.g. a length above
        the model window), produce a 400 response with an ``error`` field.
        """
        query = dict(parse_qsl(urlparse(self.path).query))
        data = {}
        length = None
        top_k = 40
        temperature = 1

        # Parse everything before any status line is written, so bad input
        # can still be answered with a proper error status.
        try:
            if 'length' in query:
                length = int(query['length'])
                data['length'] = length
            if 'top_k' in query:
                top_k = int(query['top_k'])
                data['top_k'] = top_k
            if 'temperature' in query:
                temperature = float(query['temperature'])
                data['temperature'] = temperature
        except ValueError as err:
            self._send_json(400, {'error': 'invalid query parameter: %s' % err})
            return

        if 'q' not in query:
            # Without a prompt there is nothing to generate; reply with an
            # empty object.
            self._send_json(200, {})
            return

        q = query['q']
        print('Query: "%s" length: %s top_k: %d temperature: %f' % (q, length, top_k, temperature))
        try:
            output = get_text(q, length, temperature, top_k)
        except ValueError as err:
            print('Generation rejected: %s' % err)
            self._send_json(400, {'error': str(err)})
            return
        print('Generated %s' % output)

        cleaned = get_clean_output(output)
        data['q'] = q
        # get_clean_output returns None when nothing was generated; fall back
        # to an empty continuation instead of failing on str + None.
        data['output'] = q + (cleaned or '')
        self._send_json(200, data)
def run():
    """Start the HTTP server and serve requests until interrupted.

    The port is taken from the ``PORT`` environment variable, defaulting
    to 8080; the server is bound with an empty host string.
    """
    port = int(os.environ.get('PORT', 8080))
    server_address = ('', port)
    httpd = HTTPServer(server_address, server)
    print('Running server on port %d...' % port)
    try:
        httpd.serve_forever()
    finally:
        # Release the listening socket however serve_forever() exits
        # (e.g. KeyboardInterrupt), so the port is freed promptly.
        httpd.server_close()
# Start serving only when executed as a script, so that importing this
# module does not block on a running server.
if __name__ == '__main__':
    run()