Skip to content

Commit

Permalink
Moved api out to server
Browse files Browse the repository at this point in the history
Reworked sockets to use socketio
Added progress to nodes
Added highlight to active node
Added preview to saveimage node
  • Loading branch information
pythongosssss authored Feb 21, 2023
1 parent 9280871 commit a52aa9f
Show file tree
Hide file tree
Showing 6 changed files with 393 additions and 236 deletions.
247 changes: 60 additions & 187 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,7 @@
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

try:
import aiohttp
from aiohttp import web
except ImportError:
print("Module 'aiohttp' not installed. Please install it via:")
print("pip install aiohttp")
print("or")
print("pip install -r requirements.txt")
sys.exit()
import server

if __name__ == "__main__":
if '--help' in sys.argv:
Expand All @@ -36,14 +28,14 @@
print()
exit()

if '--dont-upcast-attention' in sys.argv:
print("disabling upcasting of attention")
os.environ['ATTN_PRECISION'] = "fp16"
if '--dont-upcast-attention' in sys.argv:
print("disabling upcasting of attention")
os.environ['ATTN_PRECISION'] = "fp16"

import torch
import nodes

def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}):
def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}, server=None, unique_id=None):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
for x in inputs:
Expand All @@ -65,9 +57,13 @@ def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}):
if h[x] == "EXTRA_PNGINFO":
if "extra_pnginfo" in extra_data:
input_data_all[x] = extra_data['extra_pnginfo']
if h[x] == "SERVER":
input_data_all[x] = server
if h[x] == "UNIQUE_ID":
input_data_all[x] = unique_id
return input_data_all

def recursive_execute(prompt, outputs, current_item, extra_data={}):
def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
Expand All @@ -84,9 +80,11 @@ def recursive_execute(prompt, outputs, current_item, extra_data={}):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
executed += recursive_execute(prompt, outputs, input_unique_id, extra_data)
executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data)

input_data_all = get_input_data(inputs, class_def, outputs, prompt, extra_data)
input_data_all = get_input_data(inputs, class_def, outputs, prompt, extra_data, server, unique_id)
if server.client_id is not None:
server.send_sync("execute", { "node": unique_id }, server.client_id)
obj = class_def()

outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
Expand Down Expand Up @@ -157,11 +155,17 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
return to_delete

class PromptExecutor:
def __init__(self):
def __init__(self, server):
self.outputs = {}
self.old_prompt = {}
self.server = server

def execute(self, prompt, extra_data={}):
if "client_id" in extra_data:
self.server.client_id = extra_data["client_id"]
else:
self.server.client_id = None

with torch.no_grad():
for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
Expand Down Expand Up @@ -190,7 +194,7 @@ def execute(self, prompt, extra_data={}):
except:
valid = False
if valid:
executed += recursive_execute(prompt, self.outputs, x, extra_data)
executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data)

except Exception as e:
print(traceback.format_exc())
Expand All @@ -208,6 +212,11 @@ def execute(self, prompt, extra_data={}):
executed = set(executed)
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])

finally:
if self.server.client_id is not None:
self.server.send_sync("execute", { "node": None }, self.server.client_id)

torch.cuda.empty_cache()

def validate_inputs(prompt, item):
Expand Down Expand Up @@ -293,27 +302,27 @@ def validate_prompt(prompt):

return (True, "")

def prompt_worker(q):
e = PromptExecutor()
def prompt_worker(q, server):
e = PromptExecutor(server)
while True:
item, item_id = q.get()
e.execute(item[-2], item[-1])
q.task_done(item_id)

class PromptQueue:
def __init__(self, socket_handler):
self.socket_handler = socket_handler
def __init__(self, server):
self.server = server
self.mutex = threading.RLock()
self.not_empty = threading.Condition(self.mutex)
self.task_counter = 0
self.queue = []
self.currently_running = {}
socket_handler.prompt_queue = self
server.prompt_queue = self

def put(self, item):
with self.mutex:
heapq.heappush(self.queue, item)
self.socket_handler.queue_updated(self)
self.server.queue_updated()
self.not_empty.notify()

def get(self):
Expand All @@ -324,13 +333,13 @@ def get(self):
i = self.task_counter
self.currently_running[i] = copy.deepcopy(item)
self.task_counter += 1
self.socket_handler.queue_updated(self)
self.server.queue_updated()
return (item, i)

def task_done(self, item_id):
with self.mutex:
self.currently_running.pop(item_id)
self.socket_handler.queue_updated(self)
self.server.queue_updated()

def get_current_queue(self):
with self.mutex:
Expand All @@ -346,7 +355,7 @@ def get_tasks_remaining(self):
def wipe_queue(self):
with self.mutex:
self.queue = []
self.socket_handler.queue_updated(self)
self.server.queue_updated()

def delete_queue_item(self, function):
with self.mutex:
Expand All @@ -357,174 +366,32 @@ def delete_queue_item(self, function):
else:
self.queue.pop(x)
heapq.heapify(self.queue)
self.socket_handler.queue_updated(self)
self.server.queue_updated()
return True
return False

def get_queue_info(prompt_queue):
prompt_info = {}
exec_info = {}
exec_info['queue_remaining'] = prompt_queue.get_tasks_remaining()
prompt_info['exec_info'] = exec_info
return prompt_info

class SocketHandler():
def __init__(self, loop):
self.connected = set()
self.messages = asyncio.Queue()
self.loop = loop

async def publish_loop(self):
while True:
msg = await self.messages.get()
await self.send(msg)

def queue_updated(self, queue):
# This is called by the queue processing thread so we need to make it thread safe
loop.call_soon_threadsafe(self.messages.put_nowait, { 'type': 'status', 'status': get_queue_info(queue) })

async def send(self, message, socket = None):
if isinstance(message, str) == False:
message = json.dumps(message)

if socket is None:
for ws in self.connected:
await ws.send_str(message)
else:
await socket.send_str(message)

async def process(self, request):
ws = web.WebSocketResponse()
await ws.prepare(request)
self.connected.add(ws)
try:
# Send initial state to the new client
await self.send({ 'type': 'status', 'status': get_queue_info(self.prompt_queue) }, ws)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR:
print('ws connection closed with exception %s' % ws.exception())
finally:
self.connected.remove(ws)

return ws

class PromptServer():
def __init__(self, prompt_queue, socket_handler):
self.prompt_queue = prompt_queue
self.socket_handler = socket_handler
self.number = 0
self.app = web.Application()
self.web_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "webshit")
routes = web.RouteTableDef()

@routes.get('/ws')
async def websocket_handler(request):
return await self.socket_handler.process(request)

@routes.get("/")
async def get_root(request):
return web.FileResponse(os.path.join(self.web_root, "index.html"))

@routes.get("/prompt")
async def get_prompt(request):
return web.json_response(get_queue_info(self.prompt_queue))

@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['name'] = x #TODO
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
out[x] = info
return web.json_response(out)

@routes.get("/queue")
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
return web.json_response(queue_info)

@routes.post("/prompt")
async def post_prompt(request):
print("got prompt")
resp_code = 200
out_string = ""
json_data = await request.json()

if "number" in json_data:
number = float(json_data['number'])
else:
number = self.number
if "front" in json_data:
if json_data['front']:
number = -number

self.number += 1
if "prompt" in json_data:
prompt = json_data["prompt"]
valid = validate_prompt(prompt)
extra_data = {}
if "extra_data" in json_data:
extra_data = json_data["extra_data"]
if valid[0]:
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
else:
resp_code = 400
out_string = valid[1]
print("invalid prompt:", valid[1])

return web.Response(body=out_string, status=resp_code)

@routes.post("/queue")
async def post_queue(request):
json_data = await request.json()
if "clear" in json_data:
if json_data["clear"]:
self.prompt_queue.wipe_queue()
if "delete" in json_data:
to_delete = json_data['delete']
for id_to_delete in to_delete:
delete_func = lambda a: a[1] == int(id_to_delete)
self.prompt_queue.delete_queue_item(delete_func)

return web.Response(status=200)

self.app.add_routes(routes)
self.app.add_routes([
web.static('/', self.web_root),
])

async def start_server(server, address, port):
runner = web.AppRunner(server.app)
await runner.setup()
site = web.TCPSite(runner, address, port)
await site.start()

if address == '':
address = '0.0.0.0'
print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port))
async def run(server, address='', port=8188):
await asyncio.gather(server.start(address, port), server.publish_loop())

async def run(prompt_queue, socket_handler, address='', port=8188):
server = PromptServer(prompt_queue, socket_handler)
await asyncio.gather(start_server(server, address, port), socket_handler.publish_loop())
def hijack_progress(server):
from tqdm.auto import tqdm
orig_func = getattr(tqdm, "update")
def wrapped_func(*args, **kwargs):
pbar = args[0]
v = orig_func(*args, **kwargs)
server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id)
return v
setattr(tqdm, "update", wrapped_func)

if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server = server.PromptServer(loop)
q = PromptQueue(server)

socket_handler = SocketHandler(loop)
q = PromptQueue(socket_handler)
threading.Thread(target=prompt_worker, daemon=True, args=(q,)).start()
hijack_progress(server)

threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
if '--listen' in sys.argv:
address = '0.0.0.0'
else:
Expand All @@ -537,5 +404,11 @@ async def run(prompt_queue, socket_handler, address='', port=8188):
except:
pass

loop.run_until_complete(run(q, socket_handler, address=address, port=port))
if os.name == "nt":
try:
loop.run_until_complete(run(server, address=address, port=port))
except KeyboardInterrupt:
pass
else:
loop.run_until_complete(run(server, address=address, port=port))

Loading

0 comments on commit a52aa9f

Please sign in to comment.