From a52aa9f4b5eabf91204c0cfd4c27685c3f1be7ea Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 21 Feb 2023 19:29:49 +0000 Subject: [PATCH] Moved api out to server Reworked sockets to use socketio Added progress to nodes Added highlight to active node Added preview to saveimage node --- main.py | 247 ++++++++++----------------------------- nodes.py | 12 +- requirements.txt | 2 +- server.py | 173 +++++++++++++++++++++++++++ webshit/index.html | 188 ++++++++++++++++++++++------- webshit/socket.io.min.js | 7 ++ 6 files changed, 393 insertions(+), 236 deletions(-) create mode 100644 server.py create mode 100644 webshit/socket.io.min.js diff --git a/main.py b/main.py index 7c72bc4e0e8..283b0bd2667 100644 --- a/main.py +++ b/main.py @@ -11,15 +11,7 @@ import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) -try: - import aiohttp - from aiohttp import web -except ImportError: - print("Module 'aiohttp' not installed. Please install it via:") - print("pip install aiohttp") - print("or") - print("pip install -r requirements.txt") - sys.exit() +import server if __name__ == "__main__": if '--help' in sys.argv: @@ -36,14 +28,14 @@ print() exit() -if '--dont-upcast-attention' in sys.argv: - print("disabling upcasting of attention") - os.environ['ATTN_PRECISION'] = "fp16" + if '--dont-upcast-attention' in sys.argv: + print("disabling upcasting of attention") + os.environ['ATTN_PRECISION'] = "fp16" import torch import nodes -def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}): +def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}, server=None, unique_id=None): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} for x in inputs: @@ -65,9 +57,13 @@ def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}): if h[x] == "EXTRA_PNGINFO": if "extra_pnginfo" in extra_data: input_data_all[x] = extra_data['extra_pnginfo'] + if h[x] == "SERVER": + input_data_all[x] = server + if h[x] == "UNIQUE_ID": + input_data_all[x] = unique_id return input_data_all -def recursive_execute(prompt, outputs, current_item, extra_data={}): +def recursive_execute(server, prompt, outputs, current_item, extra_data={}): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] @@ -84,9 +80,11 @@ def recursive_execute(prompt, outputs, current_item, extra_data={}): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - executed += recursive_execute(prompt, outputs, input_unique_id, extra_data) + executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data) - input_data_all = get_input_data(inputs, class_def, outputs, prompt, extra_data) + input_data_all = get_input_data(inputs, class_def, outputs, prompt, extra_data, server, unique_id) + if server.client_id is not None: + server.send_sync("execute", { "node": unique_id }, server.client_id) obj = class_def() outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all) @@ -157,11 +155,17 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item return to_delete class PromptExecutor: - def __init__(self): + def __init__(self, server): self.outputs = {} self.old_prompt = {} + self.server = server def execute(self, prompt, extra_data={}): + if "client_id" in extra_data: + self.server.client_id = extra_data["client_id"] + else: + self.server.client_id = None + with torch.no_grad(): for x in prompt: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) @@ -190,7 +194,7 @@ def execute(self, prompt, extra_data={}): except: valid = False if valid: - executed += recursive_execute(prompt, self.outputs, x, extra_data) + executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data) except Exception as e: print(traceback.format_exc()) @@ -208,6 +212,11 @@ def execute(self, prompt, extra_data={}): executed = set(executed) for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) + + finally: + if self.server.client_id is not None: + self.server.send_sync("execute", { "node": None }, self.server.client_id) + torch.cuda.empty_cache() def validate_inputs(prompt, item): @@ -293,27 +302,27 @@ def validate_prompt(prompt): return (True, "") -def prompt_worker(q): - e = PromptExecutor() +def prompt_worker(q, server): + e = PromptExecutor(server) while True: item, item_id = q.get() e.execute(item[-2], item[-1]) q.task_done(item_id) class PromptQueue: - def __init__(self, socket_handler): - self.socket_handler = socket_handler + def __init__(self, server): + self.server = server self.mutex = threading.RLock() self.not_empty = threading.Condition(self.mutex) self.task_counter = 0 self.queue = [] self.currently_running = {} - socket_handler.prompt_queue = self + server.prompt_queue = self def put(self, item): with self.mutex: heapq.heappush(self.queue, item) - self.socket_handler.queue_updated(self) + self.server.queue_updated() self.not_empty.notify() def get(self): @@ -324,13 +333,13 @@ def get(self): i = self.task_counter self.currently_running[i] = copy.deepcopy(item) self.task_counter += 1 - self.socket_handler.queue_updated(self) + self.server.queue_updated() return (item, i) def task_done(self, item_id): with self.mutex: self.currently_running.pop(item_id) - self.socket_handler.queue_updated(self) + self.server.queue_updated() def get_current_queue(self): with self.mutex: @@ -346,7 +355,7 @@ def get_tasks_remaining(self): def wipe_queue(self): with self.mutex: self.queue = [] - self.socket_handler.queue_updated(self) + self.server.queue_updated() def delete_queue_item(self, function): with self.mutex: @@ -357,174 +366,32 @@ def delete_queue_item(self, function): else: self.queue.pop(x) heapq.heapify(self.queue) - self.socket_handler.queue_updated(self) + self.server.queue_updated() return True return False -def get_queue_info(prompt_queue): - prompt_info = {} - exec_info = {} - exec_info['queue_remaining'] = prompt_queue.get_tasks_remaining() - prompt_info['exec_info'] = exec_info - return prompt_info - -class SocketHandler(): - def __init__(self, loop): - self.connected = set() - self.messages = asyncio.Queue() - self.loop = loop - - async def publish_loop(self): - while True: - msg = await self.messages.get() - await self.send(msg) - - def queue_updated(self, queue): - # This is called by the queue processing thread so we need to make it thread safe - loop.call_soon_threadsafe(self.messages.put_nowait, { 'type': 'status', 'status': get_queue_info(queue) }) - - async def send(self, message, socket = None): - if isinstance(message, str) == False: - message = json.dumps(message) - - if socket is None: - for ws in self.connected: - await ws.send_str(message) - else: - await socket.send_str(message) - - async def process(self, request): - ws = web.WebSocketResponse() - await ws.prepare(request) - self.connected.add(ws) - try: - # Send initial state to the new client - await self.send({ 'type': 'status', 'status': get_queue_info(self.prompt_queue) }, ws) - async for msg in ws: - if msg.type == aiohttp.WSMsgType.ERROR: - print('ws connection closed with exception %s' % ws.exception()) - finally: - self.connected.remove(ws) - - return ws - -class PromptServer(): - def __init__(self, prompt_queue, socket_handler): - self.prompt_queue = prompt_queue - self.socket_handler = socket_handler - self.number = 0 - self.app = web.Application() - self.web_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "webshit") - routes = web.RouteTableDef() - - @routes.get('/ws') - async def websocket_handler(request): - return await self.socket_handler.process(request) - - @routes.get("/") - async def get_root(request): - return web.FileResponse(os.path.join(self.web_root, "index.html")) - - @routes.get("/prompt") - async def get_prompt(request): - return web.json_response(get_queue_info(self.prompt_queue)) - - @routes.get("/object_info") - async def get_object_info(request): - out = {} - for x in nodes.NODE_CLASS_MAPPINGS: - obj_class = nodes.NODE_CLASS_MAPPINGS[x] - info = {} - info['input'] = obj_class.INPUT_TYPES() - info['output'] = obj_class.RETURN_TYPES - info['name'] = x #TODO - info['description'] = '' - info['category'] = 'sd' - if hasattr(obj_class, 'CATEGORY'): - info['category'] = obj_class.CATEGORY - out[x] = info - return web.json_response(out) - - @routes.get("/queue") - async def get_queue(request): - queue_info = {} - current_queue = self.prompt_queue.get_current_queue() - queue_info['queue_running'] = current_queue[0] - queue_info['queue_pending'] = current_queue[1] - return web.json_response(queue_info) - - @routes.post("/prompt") - async def post_prompt(request): - print("got prompt") - resp_code = 200 - out_string = "" - json_data = await request.json() - - if "number" in json_data: - number = float(json_data['number']) - else: - number = self.number - if "front" in json_data: - if json_data['front']: - number = -number - - self.number += 1 - if "prompt" in json_data: - prompt = json_data["prompt"] - valid = validate_prompt(prompt) - extra_data = {} - if "extra_data" in json_data: - extra_data = json_data["extra_data"] - if valid[0]: - self.prompt_queue.put((number, id(prompt), prompt, extra_data)) - else: - resp_code = 400 - out_string = valid[1] - print("invalid prompt:", valid[1]) - - return web.Response(body=out_string, status=resp_code) - - @routes.post("/queue") - async def post_queue(request): - json_data = await request.json() - if "clear" in json_data: - if json_data["clear"]: - self.prompt_queue.wipe_queue() - if "delete" in json_data: - to_delete = json_data['delete'] - for id_to_delete in to_delete: - delete_func = lambda a: a[1] == int(id_to_delete) - self.prompt_queue.delete_queue_item(delete_func) - - return web.Response(status=200) - - self.app.add_routes(routes) - self.app.add_routes([ - web.static('/', self.web_root), - ]) - -async def start_server(server, address, port): - runner = web.AppRunner(server.app) - await runner.setup() - site = web.TCPSite(runner, address, port) - await site.start() - - if address == '': - address = '0.0.0.0' - print("Starting server\n") - print("To see the GUI go to: http://{}:{}".format(address, port)) +async def run(server, address='', port=8188): + await asyncio.gather(server.start(address, port), server.publish_loop()) -async def run(prompt_queue, socket_handler, address='', port=8188): - server = PromptServer(prompt_queue, socket_handler) - await asyncio.gather(start_server(server, address, port), socket_handler.publish_loop()) +def hijack_progress(server): + from tqdm.auto import tqdm + orig_func = getattr(tqdm, "update") + def wrapped_func(*args, **kwargs): + pbar = args[0] + v = orig_func(*args, **kwargs) + server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id) + return v + setattr(tqdm, "update", wrapped_func) if __name__ == "__main__": loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) + server = server.PromptServer(loop) + q = PromptQueue(server) - socket_handler = SocketHandler(loop) - q = PromptQueue(socket_handler) - threading.Thread(target=prompt_worker, daemon=True, args=(q,)).start() + hijack_progress(server) + + threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() if '--listen' in sys.argv: address = '0.0.0.0' else: @@ -537,5 +404,11 @@ async def run(prompt_queue, socket_handler, address='', port=8188): except: pass - loop.run_until_complete(run(q, socket_handler, address=address, port=port)) + if os.name == "nt": + try: + loop.run_until_complete(run(server, address=address, port=port)) + except KeyboardInterrupt: + pass + else: + loop.run_until_complete(run(server, address=address, port=port)) diff --git a/nodes.py b/nodes.py index 3bdad71beb7..e307f6b8aff 100644 --- a/nodes.py +++ b/nodes.py @@ -605,7 +605,7 @@ def INPUT_TYPES(s): return {"required": {"images": ("IMAGE", ), "filename_prefix": ("STRING", {"default": "ComfyUI"})}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "server": "SERVER", "unique_id": "UNIQUE_ID"}, } RETURN_TYPES = () @@ -615,7 +615,7 @@ def INPUT_TYPES(s): CATEGORY = "image" - def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None, server=None, unique_id=None): def map_filename(filename): prefix_len = len(filename_prefix) prefix = filename[:prefix_len + 1] @@ -631,6 +631,8 @@ def map_filename(filename): except FileNotFoundError: os.mkdir(self.output_dir) counter = 1 + + paths = list() for image in images: i = 255. * image.cpu().numpy() img = Image.fromarray(i.astype(np.uint8)) @@ -640,8 +642,12 @@ def map_filename(filename): if extra_pnginfo is not None: for x in extra_pnginfo: metadata.add_text(x, json.dumps(extra_pnginfo[x])) - img.save(os.path.join(self.output_dir, f"{filename_prefix}_{counter:05}_.png"), pnginfo=metadata, optimize=True) + file = f"{filename_prefix}_{counter:05}_.png" + img.save(os.path.join(self.output_dir, file), pnginfo=metadata, optimize=True) + paths.append(f"/view/{file}") counter += 1 + if server is not None: + server.send_sync("image", {"images": paths, "id": unique_id}) class LoadImage: input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") diff --git a/requirements.txt b/requirements.txt index e4be9ebc2e2..f6656b9d5f4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,4 @@ safetensors pytorch_lightning aiohttp accelerate - +python-socketio diff --git a/server.py b/server.py new file mode 100644 index 00000000000..942d2495bd2 --- /dev/null +++ b/server.py @@ -0,0 +1,173 @@ +import os +import sys +import asyncio +import nodes +import main + +try: + import aiohttp + from aiohttp import web +except ImportError: + print("Module 'aiohttp' not installed. Please install it via:") + print("pip install aiohttp") + print("or") + print("pip install -r requirements.txt") + sys.exit() + +try: + import socketio +except ImportError: + print("Module 'python-socketio' not installed. Please install it via:") + print("pip install python-socketio") + print("or") + print("pip install -r requirements.txt") + sys.exit() + + +class PromptServer(): + def __init__(self, loop): + self.prompt_queue = None + self.loop = loop + self.messages = asyncio.Queue() + self.number = 0 + self.app = web.Application() + self.sio = socketio.AsyncServer() + self.sio.attach(self.app) + self.web_root = os.path.join(os.path.dirname( + os.path.realpath(__file__)), "webshit") + routes = web.RouteTableDef() + + @self.sio.event + async def connect(sid, environ): + await self.sio.emit("status", self.get_queue_info(), sid) + + @routes.get("/") + async def get_root(request): + return web.FileResponse(os.path.join(self.web_root, "index.html")) + + @routes.get("/view/{file}") + async def view_image(request): + if "file" in request.match_info: + output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") + file = request.match_info["file"] + file = os.path.splitext(os.path.basename(file))[0] + ".png" + file = os.path.join(output_dir, file) + if os.path.isfile(file): + return web.FileResponse(file) + + return web.Response(status=404) + + @routes.get("/prompt") + async def get_prompt(request): + return web.json_response(self.get_queue_info()) + + @routes.get("/object_info") + async def get_object_info(request): + out = {} + for x in nodes.NODE_CLASS_MAPPINGS: + obj_class = nodes.NODE_CLASS_MAPPINGS[x] + info = {} + info['input'] = obj_class.INPUT_TYPES() + info['output'] = obj_class.RETURN_TYPES + info['name'] = x #TODO + info['description'] = '' + info['category'] = 'sd' + if hasattr(obj_class, 'CATEGORY'): + info['category'] = obj_class.CATEGORY + out[x] = info + return web.json_response(out) + + @routes.get("/queue") + async def get_queue(request): + queue_info = {} + current_queue = self.prompt_queue.get_current_queue() + queue_info['queue_running'] = current_queue[0] + queue_info['queue_pending'] = current_queue[1] + return web.json_response(queue_info) + + @routes.post("/prompt") + async def post_prompt(request): + print("got prompt") + resp_code = 200 + out_string = "" + json_data = await request.json() + + if "number" in json_data: + number = float(json_data['number']) + else: + number = self.number + if "front" in json_data: + if json_data['front']: + number = -number + + self.number += 1 + + if "prompt" in json_data: + prompt = json_data["prompt"] + valid = main.validate_prompt(prompt) + extra_data = {} + if "extra_data" in json_data: + extra_data = json_data["extra_data"] + + if "client_id" in json_data: + extra_data["client_id"] = json_data["client_id"] + if valid[0]: + self.prompt_queue.put((number, id(prompt), prompt, extra_data)) + else: + resp_code = 400 + out_string = valid[1] + print("invalid prompt:", valid[1]) + + return web.Response(body=out_string, status=resp_code) + + @routes.post("/queue") + async def post_queue(request): + json_data = await request.json() + if "clear" in json_data: + if json_data["clear"]: + self.prompt_queue.wipe_queue() + if "delete" in json_data: + to_delete = json_data['delete'] + for id_to_delete in to_delete: + delete_func = lambda a: a[1] == int(id_to_delete) + self.prompt_queue.delete_queue_item(delete_func) + + return web.Response(status=200) + + self.app.add_routes(routes) + self.app.add_routes([ + web.static('/', self.web_root), + ]) + + def get_queue_info(self): + prompt_info = {} + exec_info = {} + exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining() + prompt_info['exec_info'] = exec_info + return prompt_info + + async def send(self, event, data, sid=None): + await self.sio.emit(event, data, to=sid) + + def send_sync(self, event, data, sid=None): + self.loop.call_soon_threadsafe( + self.messages.put_nowait, (event, data, sid)) + + def queue_updated(self): + self.send_sync("status", self.get_queue_info()) + + async def publish_loop(self): + while True: + msg = await self.messages.get() + await self.send(*msg) + + async def start(self, address, port): + runner = web.AppRunner(self.app) + await runner.setup() + site = web.TCPSite(runner, address, port) + await site.start() + + if address == '': + address = '0.0.0.0' + print("Starting server\n") + print("To see the GUI go to: http://{}:{}".format(address, port)) \ No newline at end of file diff --git a/webshit/index.html b/webshit/index.html index 4f26f557550..981a03f348a 100644 --- a/webshit/index.html +++ b/webshit/index.html @@ -2,6 +2,7 @@
+