forked from Samagra-Development/ai-tools
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapi.py
138 lines (103 loc) · 4.02 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import importlib
from quart import Quart, g, request, jsonify, abort, Response
from markupsafe import escape
import json
import aiohttp
import os
from functools import wraps
from dotenv import load_dotenv
from quart_compress import Compress
import time
from watch_folder import watch_src_folder
import asyncio
# Call dotenv's load_dotenv() first so that it has run before the os.getenv()
# lookups for AUTH_HEADER / AUTH_HEADER_KEY further down in this module.
load_dotenv()
# The Quart application object; the route and lifecycle decorators below
# all register their handlers on it.
app = Quart(__name__)
# Wrap the app with quart_compress.
# NOTE(review): presumably this enables response compression — confirm
# against the quart_compress documentation.
Compress(app)
# Directory handed to watch_src_folder() by startup() when debug mode is on.
watch_dir = "src"
# Snapshot of app.debug taken once at import time; startup() checks this
# captured value rather than re-reading app.debug.
debug_mode = app.debug
async def create_client_session():
    """
    Create an aiohttp ClientSession.

    Returns:
        An open ``aiohttp.ClientSession``. Ownership passes to the caller,
        who is responsible for awaiting ``session.close()`` when the session
        is no longer needed.
    """
    # Do not wrap this in ``async with``: leaving the context manager closes
    # the session, so returning from inside it would hand the caller a
    # session that is already closed and unusable.
    return aiohttp.ClientSession()
# Collect the watched directories plus every regular file beneath them and
# register the resulting list on the app config under ``extra_files``.
extra_dirs = ['src']
extra_files = list(extra_dirs)
for extra_dir in extra_dirs:
    for dirname, _subdirs, files in os.walk(extra_dir):
        # Join each file name with its directory, keeping only paths that
        # still resolve to a regular file.
        candidates = (os.path.join(dirname, name) for name in files)
        extra_files.extend(path for path in candidates if os.path.isfile(path))
app.config.update(extra_files=extra_files)
# Load the model repository description once at import time. The encoding is
# given explicitly because open() otherwise falls back to the platform's
# locale encoding, which can mis-decode non-ASCII characters in a UTF-8 JSON
# file.
with open('repository_data.json', encoding='utf-8') as f:
    repository_data = json.load(f)

# Secret value and header name used by verify_auth_header(); both come from
# the environment and are None when the corresponding variable is not set.
AUTH_HEADER = os.getenv("AUTH_HEADER")
AUTH_HEADER_KEY = os.getenv("AUTH_HEADER_KEY")
def verify_auth_header(auth_header_key, expected_value):
    """Build a decorator that rejects requests lacking the expected auth header.

    Args:
        auth_header_key: Name of the request header that carries the secret.
        expected_value: Value the header must equal for the request to pass.

    Returns:
        A decorator for async view functions. The wrapped view aborts with
        HTTP 401 unless the header is present, non-empty and equal to
        ``expected_value``; otherwise it awaits the original view and
        returns its result.

    If ``auth_header_key`` is empty/None or ``expected_value`` is not a
    string (for example because the environment variable is unset), every
    request is rejected.
    """
    import hmac  # local import so the block does not rely on new file-level imports

    # Encode the secret once at decoration time. compare_digest() only
    # accepts ASCII str or bytes, so both sides are compared as UTF-8 bytes.
    if isinstance(expected_value, str):
        expected_bytes = expected_value.encode("utf-8")
    else:
        expected_bytes = None

    def decorator(f):
        @wraps(f)
        async def decorated_function(*args, **kwargs):
            # Skip the header lookup entirely when no header name is configured.
            provided = request.headers.get(auth_header_key) if auth_header_key else None
            # Constant-time comparison avoids leaking how much of the secret
            # matched through response timing.
            authorized = (
                bool(provided)
                and expected_bytes is not None
                and hmac.compare_digest(provided.encode("utf-8"), expected_bytes)
            )
            if not authorized:
                print("Unauthorized access")
                abort(401)  # Unauthorized
            return await f(*args, **kwargs)
        return decorated_function
    return decorator
@app.route("/")
def welcome():
    """Serve a minimal HTML greeting at the site root."""
    greeting = "<p>Welcome!</p>"
    return greeting
@app.route("/repository")
def repository():
    """Expose the loaded repository data (available models and their configurations) as JSON."""
    payload = jsonify(repository_data)
    return payload
def json_to_object(request_class, json_str):
    """Build a ``request_class`` instance from a JSON document.

    Only the top level is mapped: each top-level key of the decoded JSON is
    passed to ``request_class`` as a keyword argument, and nested values are
    handed over exactly as ``json.loads`` decoded them.
    """
    fields = json.loads(json_str)
    instance = request_class(**fields)
    return instance
def get_model_config(use_case, provider, mode):
    """Look up the model config for the given use case, provider and mode.

    The lookup walks ``repository_data['use_cases'][use_case][provider][mode]``.

    Returns:
        A ``(result, status)`` tuple. On success ``result`` is the stored
        config and ``status`` is 200. If any level is missing, ``result`` is
        an error message naming the missing use case, provider or mode (the
        caller-supplied name is HTML-escaped) and ``status`` is 400.
    """
    # Treat a repository file without a 'use_cases' section as "nothing
    # available" instead of failing with AttributeError on None.
    use_cases = repository_data.get('use_cases') or {}
    use_case_data = use_cases.get(use_case)
    if use_case_data is None:
        return f'{escape(use_case)} Use case is not available', 400
    provider_data = use_case_data.get(provider)
    if provider_data is None:
        return f'{escape(provider)} Provider is not available', 400
    # Keep the lookup result in its own name: rebinding ``mode`` here would
    # make the error message below report "None" instead of the requested mode.
    mode_config = provider_data.get(mode)
    if mode_config is None:
        return f'{escape(mode)} Mode is not available', 400
    return mode_config, 200
@app.route("/<use_case>/<provider>/<mode>", methods=['POST'])
@verify_auth_header(AUTH_HEADER_KEY, AUTH_HEADER)
async def transformer(use_case, provider, mode):
    """Run the configured model for the given use case, provider and mode.

    ``use_case``, ``provider`` and ``mode`` must match an entry in the
    repository data. That entry names the model class and request class,
    which are loaded from the module ``src.<use_case>.<provider>.<mode>``.
    The JSON request body supplies the keyword arguments for the request
    class.

    Returns:
        A ``Response`` wrapping the model's inference result, with the
        elapsed handling time in milliseconds in the
        ``ai-tools-response-time`` header. If the route does not match a
        configured model, or the body is not a JSON object, a
        ``(message, 400)`` tuple is returned instead.
    """
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments the way time.time() differences can.
    start_time = time.perf_counter()

    config, status = get_model_config(use_case, provider, mode)
    if status != 200:
        return config, status

    # The body is expanded into keyword arguments below, so anything other
    # than a JSON object (including a missing body) is a client error rather
    # than something to let fail later with a TypeError.
    payload = await request.json
    if not isinstance(payload, dict):
        return 'Request body must be a JSON object', 400

    # Only names that passed the repository lookup above reach the import.
    module = importlib.import_module("src" + "." + use_case + "." + provider + "." + mode)
    model = getattr(module, config.get('model_class'))(app)
    request_class = getattr(module, config.get('request_class'))
    model_request = json_to_object(request_class, json.dumps(payload))

    # The config flag tells us whether the model's inference() must be awaited.
    if config.get("__is_async"):
        response = await model.inference(model_request)
    else:
        response = model.inference(model_request)

    response_time_ms = int((time.perf_counter() - start_time) * 1000)
    headers = {"ai-tools-response-time": str(response_time_ms)}
    return Response(response, headers=headers)
@app.before_serving
async def startup():
    """
    Startup function called before serving requests.

    Stores the session returned by create_client_session() on ``app.client``
    and, when ``debug_mode`` is set, starts watch_src_folder() as a
    background task.
    """
    app.client = await create_client_session()
    # monitor src if in debug mode
    if debug_mode:
        # Keep a strong reference on the app: the event loop only holds weak
        # references to tasks, so a task nobody references may be garbage
        # collected before it finishes.
        app.src_watch_task = asyncio.create_task(watch_src_folder(app, watch_dir))


@app.after_serving
async def shutdown():
    """
    Shutdown function called after the server stops serving requests.

    Cancels the source-folder watcher task if one is still running and
    closes the client session stored on ``app.client`` if it is still open.
    """
    watch_task = getattr(app, "src_watch_task", None)
    if watch_task is not None and not watch_task.done():
        watch_task.cancel()
    client = getattr(app, "client", None)
    if client is not None and not client.closed:
        await client.close()
# quart --app api --debug run
# hypercorn api -b 0.0.0.0:8000