codebotler.py
#! /usr/bin/env python3
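"""CodeBotler server: serves the code-generation web interface over HTTP,
accepts prompts over a websocket, generates task programs with the configured
language model, and optionally executes them on a robot via ROS."""
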
import os
import threading
import http.server
import socketserver
import asyncio
import websockets
import json
import signal
import time
import sys
from models.model_factory import load_model
from models.OpenAIChatModel import OpenAIChatModel
ros_available = False
robot_available = False
robot_interface = None
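# rospy is only importable in a ROS environment; without it the server still
# runs, but execute requests are ignored.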
try:
    import rospy
    ros_available = True
    rospy.init_node('ros_interface', anonymous=False)
except Exception:
    print("Could not import rospy. Robot interface is not available.")
    ros_available = False

httpd = None
server_thread = None
model = None
asyncio_loop = None
ws_server = None
prompt_prefix = ""
prompt_suffix = ""
def serve_interface_html(args):
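    """Serve the single-page web interface over HTTP, rewriting the websocket
    URL in the page to point at this server's websocket endpoint."""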
    global httpd

    class HTMLFileHandler(http.server.SimpleHTTPRequestHandler):
        def do_GET(self):
            self.send_response(200)
            self.send_header('Content-type', 'text/html')
            self.end_headers()
            with open(args.interface_page, 'r') as file:
                html = file.read()
            html = html.replace("ws://localhost:8190",
                                f"ws://{args.ip}:{args.ws_port}")
            self.wfile.write(bytes(html, 'utf8'))

    print(f"Starting server at http://{args.ip}:{args.port}")
    try:
        httpd = http.server.HTTPServer((args.ip, args.port), HTMLFileHandler)
        httpd.serve_forever()
    except Exception as e:
        print("HTTP server error: " + str(e))
        shutdown(None, None)

def generate_code(prompt, args):
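    """Generate a task program from a natural-language prompt using the
    configured model. Non-chat models get the prompt wrapped with the
    prefix/suffix templates; the suffix is prepended to the completion so
    the returned code is self-contained."""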
    global model, prompt_prefix, prompt_suffix, code_timeout
    start_time = time.time()
    stop_sequences = ["\n#", "\nclass", "```"]
    if args.model_type != "openai-chat":
        prompt = prompt_prefix + prompt + prompt_suffix
        stop_sequences += ["\ndef"]
    code = model.generate_one(prompt=prompt,
                              stop_sequences=stop_sequences,
                              temperature=args.temperature,
                              top_p=args.top_p,
                              max_tokens=args.max_tokens)
    end_time = time.time()
    print(f"Code generation time: {round(end_time - start_time, 2)} seconds")
    if type(model) is not OpenAIChatModel:
        code = (prompt_suffix + code).strip()
    elif not code.startswith(prompt_suffix.strip()):
        code = (prompt_suffix + "\n" + code).strip()
    return code

def execute(code):
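    """Run generated code on the robot in a background thread, if both ROS
    and the robot interface are available; otherwise log and ignore."""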
    global ros_available
    global robot_available
    global robot_interface
    if not ros_available:
        print("ROS not available. Ignoring execute request.")
    elif not robot_available:
        print("Robot not available. Ignoring execute request.")
    else:
        from robot_interface.src.robot_client_interface import execute_task_program
        robot_execution_thread = threading.Thread(target=execute_task_program,
                                                  name="robot_execute",
                                                  args=[code, robot_interface])
        robot_execution_thread.start()

async def handle_message(websocket, message, args):
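    """Dispatch one websocket message: 'code' generates (and optionally
    executes) a program and returns it to the client, 'execute' runs code
    supplied by the client, and 'eval' is currently a placeholder."""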
    data = json.loads(message)
    if data['type'] == 'code':
        print("Received code generation request")
        code = generate_code(data['prompt'], args)
        response = {"code": f"{code}"}
        await websocket.send(json.dumps(response))
        if data['execute']:
            print("Executing generated code")
            execute(code)
    elif data['type'] == 'eval':
        print("Received eval request")
        # await eval(websocket, data)
    elif data['type'] == 'execute':
        print("Executing generated code")
        execute(data['code'])
        await websocket.close()
    else:
        print("Unknown message type: " + data['type'])

async def ws_main(websocket, path, args):
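    """Per-connection handler: process incoming messages until the client
    disconnects."""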
    try:
        async for message in websocket:
            await handle_message(websocket, message, args)
    except websockets.exceptions.ConnectionClosed:
        pass

def start_completion_callback(args):
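    """Create an event loop, start the websocket server on it, and run the
    loop until shutdown."""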
    global asyncio_loop, ws_server
    # Create an asyncio event loop
    asyncio_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(asyncio_loop)
    start_server = websockets.serve(lambda ws, path: ws_main(ws, path, args),
                                    args.ip, args.ws_port)
    try:
        ws_server = asyncio_loop.run_until_complete(start_server)
        asyncio_loop.run_forever()
    except Exception as e:
        print("Websocket error: " + str(e))
        shutdown(None, None)

def shutdown(sig, frame):
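    """Signal handler and cleanup: preempt robot actions, shut down ROS, the
    HTTP server, and the websocket server, then exit."""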
    global ros_available, robot_available, robot_interface, server_thread, asyncio_loop, httpd, ws_server
    print(" Shutting down server.")
    if robot_available and ros_available and robot_interface is not None:
        robot_interface._cancel_goals()
        print("Waiting for 2s to preempt robot actions...")
        time.sleep(2)
    if ros_available:
        rospy.signal_shutdown("Shutting down Server")
    if httpd is not None:
        httpd.server_close()
        httpd.shutdown()
    if server_thread is not None and threading.current_thread() != server_thread:
        server_thread.join()
    if asyncio_loop is not None:
        for task in asyncio.all_tasks(loop=asyncio_loop):
            task.cancel()
        asyncio_loop.stop()
    if ws_server is not None:
        ws_server.close()
    if sig == signal.SIGINT or sig == signal.SIGTERM:
        exit_code = 0
    else:
        exit_code = 1
    sys.exit(exit_code)

def main():
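    """Parse arguments, load prompt templates and the model, start the HTTP
    interface server in a background thread, and run the websocket server."""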
    global server_thread
    global prompt_prefix
    global prompt_suffix
    global ros_available
    global robot_available
    global robot_interface
    global code_timeout
    global model

    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument('--ip', type=str, help='IP address', default="localhost")
    parser.add_argument('--port', type=int, help='HTML server port number', default=8080)
    parser.add_argument('--ws-port', type=int, help='Websocket server port number', default=8190)
    parser.add_argument("--model-type", choices=["openai", "openai-chat", "palm", "automodel", "hf-textgen"], default="openai-chat")
    parser.add_argument('--model-name', type=str, help='Model name', default='gpt-4')
    parser.add_argument('--tgi-server-url', type=str, help='Text Generation Inference Client URL', default='http://127.0.0.1:8082')
    parser.add_argument('--chat-prompt-prefix', type=Path, help='Prompt prefix for GPT chat completion only', default='code_generation/openai_chat_completion_prefix.py')
    parser.add_argument('--prompt-prefix', type=Path, help='Prompt prefix for all but GPT chat completion', default='code_generation/prompt_prefix.py')
    parser.add_argument('--prompt-suffix', type=Path, help='Prompt suffix for all but GPT chat completion', default='code_generation/prompt_suffix.py')
    parser.add_argument('--interface-page', type=Path, help='Interface page', default='code_generation/interface.html')
    parser.add_argument('--max-workers', type=int, help='Maximum number of workers', default=1)
    parser.add_argument("--max-tokens", type=int, default=512)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument('--robot', action='store_true', help='Flag to indicate if the robot is available')
    parser.add_argument('--timeout', type=int, help='Code generation timeout in seconds', default=20)

    if ros_available:
        args = parser.parse_args(rospy.myargv()[1:])
    else:
        args = parser.parse_args()

    robot_available = args.robot
    code_timeout = args.timeout
    signal.signal(signal.SIGINT, shutdown)

    if robot_available and ros_available:
        from robot_interface.src.robot_client_interface import RobotInterface
        robot_interface = RobotInterface()

    prompt_prefix = args.prompt_prefix.read_text()
    prompt_suffix = args.prompt_suffix.read_text()
    model = load_model(args)

    server_thread = threading.Thread(target=serve_interface_html,
                                     name="HTTP server thread",
                                     args=[args])
    server_thread.start()
    start_completion_callback(args)

if __name__ == "__main__":
    main()