forked from jordip/prompt-generator-api
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
157 lines (136 loc) · 4.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
Provides a simple Python API to generate prompts for AI image generation
"""
import os
import re
import logging
import json
import uuid
from flask import Flask
from flask_restful import Resource, Api, reqparse, abort
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# default config
# When True, each request's generated prompts are also dumped to
# "<uuid>.json" in the working directory (see PromptGenerator.post).
save_to_file = False

# Optional request arguments: the expected type, the fallback default,
# and the [min, max] range consumed by PromptGenerator.validate_args.
default_args = {
    'temperature': {
        'type': float,
        'default': 0.9,
        'range': [0, 1]
    },
    'top_k': {
        'type': int,
        'default': 8,
        'range': [1, 200]
    },
    'max_length': {
        'type': int,
        'default': 80,
        'range': [1, 200]
    },
    'repetition_penalty': {
        'type': float,
        'default': 1.2,
        'range': [0, 10]
    },
    'num_return_sequences': {
        'type': int,
        'default': 5,
        'range': [1, 5]
    },
}

app = Flask("PromptGeneratorAPI")
api = Api(app)

# Request parser: 'prompt' is the only mandatory field.
parser = reqparse.RequestParser()
parser.add_argument('prompt', required=True)
# optional arguments
# Register every tunable from default_args with its declared type;
# absent values arrive as None and fall back to the defaults.
for arg, def_arg in default_args.items():
    parser.add_argument(arg, type=def_arg['type'], required=False)

# limit the amount of requests per user
# Rate limiting is keyed on the client's remote address.
limiter = Limiter(
    get_remote_address,
    app=app,
    default_limits=["20 per minute"]
)
class PromptGenerator(Resource):
    """Flask-RESTful resource that generates AI-image prompts.

    POST /generate takes a mandatory seed ``prompt`` plus the optional
    sampling arguments declared in ``default_args`` and returns a JSON
    list of generated prompt strings.

    Args:
        Resource Resource: Flask restful resource
    """

    # Class-level cache so the expensive tokenizer/model construction
    # happens once per process instead of once per request.
    _tokenizer = None
    _model = None

    @classmethod
    def _get_model(cls):
        """Lazily build and cache the tokenizer and generation model.

        Returns:
            tuple: (GPT2Tokenizer, GPT2LMHeadModel), loaded on first call.
        """
        if cls._model is None:
            tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            cls._tokenizer = tokenizer
            cls._model = GPT2LMHeadModel.from_pretrained(
                'FredZhang7/distilgpt2-stable-diffusion-v2')
        return cls._tokenizer, cls._model

    def validate_args(self, args):
        """Validate ranges and collect the effective generation settings.

        Previously this wrote each value into ``globals()``, which is not
        safe under concurrent requests; it now returns the values instead.

        Args:
            args dict: Arguments provided in the request.

        Returns:
            dict: argument name -> validated value (or its default).
        """
        settings = {}
        for arg, def_arg in default_args.items():
            # 'is not None' (not truthiness) so an explicit 0 / 0.0 is
            # honored when the declared range allows it.
            if arg in args and args[arg] is not None:
                low, high = def_arg['range']
                # Fixed: the original chained comparison (low < x > high)
                # only fired when x exceeded BOTH bounds, so values below
                # the minimum were never rejected. 400 = client error.
                if not low <= args[arg] <= high:
                    abort(400,
                          message=f"{arg} out of range. Min {low}, Max {high}")
                settings[arg] = args[arg]
            else:
                settings[arg] = def_arg['default']
        return settings

    def get_blacklist(self):
        """Check and load blacklist.

        Returns:
            list: List of terms from the blacklist dictionary.
        """
        blacklist_filename = 'blacklist.txt'
        blacklist = []
        if not os.path.exists(blacklist_filename):
            logging.warning("Blacklist file missing: %s", blacklist_filename)
            return blacklist
        with open(blacklist_filename, 'r') as f:
            for line in f:
                # Strip the trailing newline: with it, the term could
                # never match inside a generated prompt. Skip blank lines.
                term = line.strip()
                if term:
                    blacklist.append(term)
        return blacklist

    def post(self):
        """Post method.

        Returns:
            list: JSON list with the generated prompts.
        """
        args = parser.parse_args()
        settings = self.validate_args(args)
        prompt = args['prompt']
        request_uuid = uuid.uuid4()
        try:
            # build (or reuse the cached) model
            tokenizer, model = self._get_model()
        except Exception as e:
            logging.error(
                "Exception encountered while attempting to load tokenizer/model: %s", e)
            abort(500, message="There was an error processing your request")
        try:
            # generate prompt
            logging.debug("Generate new prompt from: \"%s\"", prompt)
            input_ids = tokenizer(prompt, return_tensors='pt').input_ids
            output = model.generate(input_ids, do_sample=True,
                                    temperature=settings['temperature'],
                                    top_k=settings['top_k'],
                                    max_length=settings['max_length'],
                                    num_return_sequences=settings['num_return_sequences'],
                                    repetition_penalty=settings['repetition_penalty'],
                                    penalty_alpha=0.6, no_repeat_ngram_size=1,
                                    early_stopping=True)
            blacklist = self.get_blacklist()
            prompt_output = []
            for value in output:
                text = tokenizer.decode(value, skip_special_tokens=True)
                for term in blacklist:
                    # re.escape: blacklist entries are literal terms, so
                    # regex metacharacters must not be interpreted.
                    text = re.sub(re.escape(term), "", text,
                                  flags=re.IGNORECASE)
                prompt_output.append(text)
            # save results to file
            if save_to_file:
                with open(f"{request_uuid}.json", 'w') as f:
                    json.dump(prompt_output, f)
            return prompt_output
        except Exception as e:
            logging.error(
                "Exception encountered while attempting to generate prompt: %s", e)
            abort(500, message="There was an error processing your request")
# Expose the generator as the single API endpoint.
api.add_resource(PromptGenerator, '/generate')

if __name__ == '__main__':
    # Flask development server; use a WSGI server (gunicorn etc.) in production.
    app.run(debug=False)