diff --git a/.gitmodules b/.gitmodules index 6b450dd..6a800cb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "extern/SadTalker"] path = extern/SadTalker url = https://github.com/OpenTalker/SadTalker.git +[submodule "extern/GFPGAN"] + path = extern/GFPGAN + url = https://github.com/TencentARC/GFPGAN.git diff --git a/__init__.py b/__init__.py index 4f19a16..977694d 100644 --- a/__init__.py +++ b/__init__.py @@ -1,10 +1,11 @@ import traceback -from .log import log, blue_text, get_summary, get_label +from .log import log, blue_text, cyan_text, get_summary, get_label from .utils import here import importlib import os NODE_CLASS_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS = {} NODE_CLASS_MAPPINGS_DEBUG = {} @@ -64,16 +65,17 @@ def load_nodes(): nodes = load_nodes() for node_class in nodes: class_name = node_class.__name__ - class_name = node_class.__name__ - node_name = f"{get_label(class_name)} (mtb)" - NODE_CLASS_MAPPINGS[node_name] = node_class - NODE_CLASS_MAPPINGS_DEBUG[node_name] = node_class.__doc__ - + node_label = f"{get_label(class_name)} (mtb)" + NODE_CLASS_MAPPINGS[node_label] = node_class + NODE_DISPLAY_NAME_MAPPINGS[class_name] = node_label + NODE_CLASS_MAPPINGS_DEBUG[node_label] = node_class.__doc__ + # TODO: I removed this; I find it more convenient to write without spaces, but it breaks all of my workflows + # TODO (cont): and until I find a way to automate the conversion, I'll leave it like this -log.debug( +log.info( f"Loaded the following nodes:\n\t" + "\n\t".join( - f"{k}: {blue_text(get_summary(doc)) if doc else '-'}" + f"{cyan_text(k)}: {blue_text(get_summary(doc)) if doc else '-'}" for k, doc in NODE_CLASS_MAPPINGS_DEBUG.items() ) ) diff --git a/log.py b/log.py index 1924bcc..c538bff 100644 --- a/log.py +++ b/log.py @@ -1,9 +1,21 @@ import logging import re +import os + +base_log_level = logging.DEBUG if os.environ.get("MTB_DEBUG") else logging.INFO +print(f"Log level: {base_log_level}") + + +# Custom object that discards the output +class NullWriter: + def write(self, text): + pass class Formatter(logging.Formatter): grey = "\x1b[38;20m" + cyan = "\x1b[36;20m" + purple = "\x1b[35;20m" yellow = "\x1b[33;20m" red = "\x1b[31;20m" bold_red = "\x1b[31;1m" @@ -12,8 +24,8 @@ class Formatter(logging.Formatter): format = "[%(name)s] | %(levelname)s -> %(message)s" FORMATS = { - logging.DEBUG: grey + format + reset, - logging.INFO: grey + format + reset, + logging.DEBUG: purple + format + reset, + logging.INFO: cyan + format + reset, logging.WARNING: yellow + format + reset, logging.ERROR: red + format + reset, logging.CRITICAL: bold_red + format + reset, @@ -25,21 +37,26 @@ def format(self, record): return formatter.format(record) -def mklog(name, level=logging.DEBUG): +def mklog(name, level=base_log_level): logger = logging.getLogger(name) logger.setLevel(level) - # create console handler with a higher log level - ch = logging.StreamHandler() - ch.setLevel(logging.DEBUG) - ch.setFormatter(Formatter()) + for handler in logger.handlers: + logger.removeHandler(handler) + ch = logging.StreamHandler() + ch.setLevel(level) + ch.setFormatter(Formatter()) logger.addHandler(ch) + + # Disable log propagation + logger.propagate = False + return logger # - The main app logger -log = mklog(__package__) +log = mklog(__package__, base_log_level) def log_user(arg): @@ -54,6 +71,10 @@ def blue_text(text): return f"\033[94m{text}\033[0m" +def cyan_text(text): + return f"\033[96m{text}\033[0m" + + def get_label(label): words = re.findall(r"(?:^|[A-Z])[a-z]*", 
label) return " ".join(words).strip() diff --git a/nodes/faceenhance.py new file mode 100644 index 0000000..0f1bde9 --- /dev/null +++ b/nodes/faceenhance.py @@ -0,0 +1,235 @@ +import logging +from gfpgan import GFPGANer +import cv2 +import numpy as np +import os +import sys +from pathlib import Path +import folder_paths +from basicsr.utils import imwrite +from PIL import Image +from ..utils import pil2tensor, tensor2pil, np2tensor, tensor2np +import torch +from munch import Munch +from ..log import NullWriter, log +from comfy import model_management +import comfy +from typing import Tuple + + +class LoadFaceEnhanceModel: + def __init__(self) -> None: + pass + + @classmethod + def get_models_root(cls): + return Path(folder_paths.models_dir) / "upscale_models" + + @classmethod + def get_models(cls): + models_path = cls.get_models_root() + + return [ + x + for x in models_path.iterdir() + if x.name.endswith(".pth") + and ("GFPGAN" in x.name or "RestoreFormer" in x.name) + ] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model_name": ( + [x.name for x in cls.get_models()], + {"default": "None"}, + ), + "upscale": ("INT", {"default": 2}), + }, + "optional": {"bg_upsampler": ("UPSCALE_MODEL", {"default": None})}, + } + + RETURN_TYPES = ("FACEENHANCE_MODEL",) + RETURN_NAMES = ("model",) + FUNCTION = "load_model" + CATEGORY = "face" + + def load_model(self, model_name, upscale=2, bg_upsampler=None): + basic = "RestoreFormer" not in model_name + + root = self.get_models_root() + + if bg_upsampler is not None: + log.warning( + f"Upscale value overridden to {bg_upsampler.scale} from bg_upsampler" + ) + upscale = bg_upsampler.scale + bg_upsampler = BGUpscaleWrapper(bg_upsampler) + + sys.stdout = NullWriter() + model = GFPGANer( + model_path=(root / model_name).as_posix(), + upscale=upscale, + arch="clean" if basic else "RestoreFormer", # or original for v1.0 only + channel_multiplier=2, # 1 for v1.0 only + bg_upsampler=bg_upsampler, + ) + + sys.stdout = sys.__stdout__ + return (model,) + + +class BGUpscaleWrapper: + def __init__(self, upscale_model) -> None: + self.upscale_model = upscale_model + + def enhance(self, img: Image.Image, outscale=2): + device = model_management.get_torch_device() + self.upscale_model.to(device) + + tile = 128 + 64 + overlap = 8 + + imgt = np2tensor(img) + imgt = imgt.movedim(-1, -3).to(device) + + steps = imgt.shape[0] * comfy.utils.get_tiled_scale_steps( + imgt.shape[3], imgt.shape[2], tile_x=tile, tile_y=tile, overlap=overlap + ) + + log.debug(f"Steps: {steps}") + + pbar = comfy.utils.ProgressBar(steps) + + s = comfy.utils.tiled_scale( + imgt, + lambda a: self.upscale_model(a), + tile_x=tile, + tile_y=tile, + overlap=overlap, + upscale_amount=self.upscale_model.scale, + pbar=pbar, + ) + + self.upscale_model.cpu() + s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0) + return (tensor2np(s),) + + + + + +class RestoreFace: + def __init__(self) -> None: + pass + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "restore" + CATEGORY = "face" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + "model": ("FACEENHANCE_MODEL",), + # Inputs are aligned faces + "aligned": (["true", "false"], {"default": "false"}), + # Only restore the center face + "only_center_face": (["true", "false"], {"default": "false"}), + # Adjustable weights + "weight": ("FLOAT", {"default": 0.5}), + "save_tmp_steps": (["true", "false"], {"default": "true"}), + } + } + + def do_restore( + self, + image: torch.Tensor, + model: 
GFPGANer, + aligned, + only_center_face, + weight, + save_tmp_steps, + ) -> torch.Tensor: + pimage = tensor2pil(image) + width, height = pimage.size + + source_img = cv2.cvtColor(np.array(pimage), cv2.COLOR_RGB2BGR) + + sys.stdout = NullWriter() + cropped_faces, restored_faces, restored_img = model.enhance( + source_img, + has_aligned=aligned, + only_center_face=only_center_face, + paste_back=True, + # TODO: weight has no effect in 1.3 and 1.4 (only tested these for now...) + weight=weight, + ) + sys.stdout = sys.__stdout__ + log.warning(f"Weight value has no effect for now. (value: {weight})") + + if save_tmp_steps: + self.save_intermediate_images(cropped_faces, restored_faces, height, width) + output = None + if restored_img is not None: + output = Image.fromarray(cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)) + # imwrite(restored_img, save_restore_path) + + return pil2tensor(output) + + def restore( + self, + image: torch.Tensor, + model: GFPGANer, + aligned="false", + only_center_face="false", + weight=0.5, + save_tmp_steps="true", + ) -> Tuple[torch.Tensor]: + save_tmp_steps = save_tmp_steps == "true" + aligned = aligned == "true" + only_center_face = only_center_face == "true" + + out = [ + self.do_restore( + image[i], model, aligned, only_center_face, weight, save_tmp_steps + ) + for i in range(image.size(0)) + ] + + return (torch.cat(out, dim=0),) + + def get_step_image_path(self, step, idx): + ( + full_output_folder, + filename, + counter, + _subfolder, + _filename_prefix, + ) = folder_paths.get_save_image_path( + f"{step}_{idx:03}", + folder_paths.temp_directory, + ) + file = f"{filename}_{counter:05}_.png" + + return os.path.join(full_output_folder, file) + + def save_intermediate_images(self, cropped_faces, restored_faces, height, width): + for idx, (cropped_face, restored_face) in enumerate( + zip(cropped_faces, restored_faces) + ): + face_id = idx + 1 + file = self.get_step_image_path("cropped_faces", face_id) + imwrite(cropped_face, file) + + file = self.get_step_image_path("cropped_faces_restored", face_id) + imwrite(restored_face, file) + + file = self.get_step_image_path("cropped_faces_compare", face_id) + + # save comparison image + cmp_img = np.concatenate((cropped_face, restored_face), axis=1) + imwrite(cmp_img, file) + + +__nodes__ = [RestoreFace, LoadFaceEnhanceModel] diff --git a/nodes/faceswap.py b/nodes/faceswap.py index 7b504fb..dc1e3f3 100644 --- a/nodes/faceswap.py +++ b/nodes/faceswap.py @@ -1,5 +1,6 @@ # region imports from ifnude import detect +import onnxruntime from pathlib import Path from PIL import Image from typing import List, Set, Tuple @@ -8,18 +9,58 @@ import glob import insightface import numpy as np -import onnxruntime import os import tempfile import torch - +from insightface.model_zoo.inswapper import INSwapper from ..utils import pil2tensor, tensor2pil -from ..log import mklog +from ..log import mklog, NullWriter +import sys # endregion -logger = mklog(__name__) -providers = onnxruntime.get_available_providers() +log = mklog(__name__) + + +class LoadFaceSwapModel: + """Loads a faceswap model""" + + @staticmethod + def get_models() -> List[Path]: + models_path = os.path.join(folder_paths.models_dir, "insightface/*") + models = glob.glob(models_path) + models = [Path(x) for x in models if x.endswith(".onnx") or x.endswith(".pth")] + return models + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "faceswap_model": ( + [x.name for x in cls.get_models()], + {"default": "None"}, + ), + }, + } + + RETURN_TYPES = 
("FACESWAP_MODEL",) + FUNCTION = "load_model" + CATEGORY = "face" + + def load_model(self, faceswap_model: str): + model_path = os.path.join( + folder_paths.models_dir, "insightface", faceswap_model + ) + log.info(f"Loading model {model_path}") + return ( + INSwapper( + model_path, + onnxruntime.InferenceSession( + path_or_bytes=model_path, + providers=onnxruntime.get_available_providers(), + ), + ), + ) # region roop node @@ -32,13 +73,6 @@ class FaceSwap: def __init__(self) -> None: pass - @staticmethod - def get_models() -> List[Path]: - models_path = os.path.join(folder_paths.models_dir, "insightface/*") - models = glob.glob(models_path) - models = [Path(x) for x in models if x.endswith(".onnx") or x.endswith(".pth")] - return models - @classmethod def INPUT_TYPES(cls): return { @@ -46,10 +80,8 @@ def INPUT_TYPES(cls): "image": ("IMAGE",), "reference": ("IMAGE",), "faces_index": ("STRING", {"default": "0"}), - "faceswap_model": ( - [x.name for x in cls.get_models()], - {"default": "None"}, - ), + "faceswap_model": ("FACESWAP_MODEL", {"default": "None"}), + "allow_nsfw": (["true", "false"], {"default": "false"}), }, "optional": {"debug": (["true", "false"], {"default": "false"})}, } @@ -63,8 +95,9 @@ def swap( image: torch.Tensor, reference: torch.Tensor, faces_index: str, - faceswap_model: str, - debug: str, + faceswap_model, + allow_nsfw="fase", + debug="false", ): def do_swap(img): img = tensor2pil(img) @@ -72,13 +105,16 @@ def do_swap(img): face_ids = { int(x) for x in faces_index.strip(",").split(",") if x.isnumeric() } - model = self.getFaceSwapModel(faceswap_model) - swapped = swap_face(ref, img, model, face_ids) + sys.stdout = NullWriter() + swapped = swap_face( + ref, img, faceswap_model, face_ids, allow_nsfw == "true" + ) + sys.stdout = sys.__stdout__ return pil2tensor(swapped) batch_count = image.size(0) - logger.info(f"Running insightface swap (batch size: {batch_count})") + log.info(f"Running insightface swap (batch size: {batch_count})") if reference.size(0) != 1: raise ValueError("Reference image must have batch size 1") @@ -91,31 +127,20 @@ def do_swap(img): return (image,) - def getFaceSwapModel(self, model_path: str): - model_path = os.path.join(folder_paths.models_dir, "insightface", model_path) - if self.model_path is None or self.model_path != model_path: - logger.info(f"Loading model {model_path}") - self.model_path = model_path - self.model = insightface.model_zoo.get_model( - model_path, providers=providers - ) - else: - logger.info("Using cached model") - - logger.info("Model loaded") - return self.model - # endregion # region face swap utils def get_face_single(img_data: np.ndarray, face_index=0, det_size=(640, 640)): - face_analyser = insightface.app.FaceAnalysis(name="buffalo_l", providers=providers) + face_analyser = insightface.app.FaceAnalysis( + name="buffalo_l", root=os.path.join(folder_paths.models_dir, "insightface") + ) face_analyser.prepare(ctx_id=0, det_size=det_size) face = face_analyser.get(img_data) if len(face) == 0 and det_size[0] > 320 and det_size[1] > 320: + log.debug("No face detected, trying again with smaller image") det_size_half = (det_size[0] // 2, det_size[1] // 2) return get_face_single(img_data, face_index=face_index, det_size=det_size_half) @@ -136,14 +161,19 @@ def swap_face( target_img: Image.Image, face_swapper_model=None, faces_index: Set[int] = None, + allow_nsfw=False, ) -> Image.Image: if faces_index is None: faces_index = {0} - logger.info(f"Swapping faces: {faces_index}") + log.debug(f"Swapping faces: {faces_index}") 
result_image = target_img converted = convert_to_sd(target_img) - scale, fn = converted[0], converted[1] - if face_swapper_model is not None and not scale: + nsfw, fn = converted[0], converted[1] + + if nsfw and allow_nsfw: + nsfw = False + + if face_swapper_model is not None and not nsfw: if isinstance(source_img, str): # source_img is a base64 string import base64, io @@ -165,19 +195,21 @@ def swap_face( for face_num in faces_index: target_face = get_face_single(target_img, face_index=face_num) if target_face is not None: + sys.stdout = NullWriter() result = face_swapper_model.get(result, target_face, source_face) + sys.stdout = sys.__stdout__ else: - logger.warning(f"No target face found for {face_num}") + log.warning(f"No target face found for {face_num}") result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) else: - logger.warning("No source face found") + log.warning("No source face found") else: - logger.error("No face swap model provided") + log.error("No face swap model provided") return result_image # endregion face swap utils -__nodes__ = [FaceSwap] +__nodes__ = [FaceSwap, LoadFaceSwapModel] diff --git a/nodes/image_processing.py b/nodes/image_processing.py index deb83de..02abba5 100644 --- a/nodes/image_processing.py +++ b/nodes/image_processing.py @@ -6,7 +6,7 @@ import numpy as np import torchvision.transforms.functional as F from PIL import Image, ImageChops -from ..utils import tensor2pil, pil2tensor, img_np_to_tensor, img_tensor_to_np +from ..utils import tensor2pil, pil2tensor, np2tensor, tensor2np import cv2 import torch from ..log import log @@ -18,7 +18,7 @@ try: from cv2.ximgproc import guidedFilter except ImportError: - log.error("guidedFilter not found, use opencv-contrib-python") + log.warning("cv2.ximgproc.guidedFilter not found, use opencv-contrib-python") class ColorCorrect: @@ -362,7 +362,7 @@ def INPUT_TYPES(cls): FUNCTION = "deglaze_image" def deglaze_image(self, image): - return (img_np_to_tensor(deglaze_np_img(img_tensor_to_np(image))),) + return (np2tensor(deglaze_np_img(tensor2np(image))),) class MaskToImage: @@ -388,7 +388,7 @@ def INPUT_TYPES(cls): FUNCTION = "render_mask" def render_mask(self, mask, color, background): - mask = img_tensor_to_np(mask) + mask = tensor2np(mask) mask = Image.fromarray(mask).convert("L") image = Image.new("RGBA", mask.size, color=color) diff --git a/requirements.txt b/requirements.txt index fca2f7a..f56dd43 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,6 @@ ifnude==0.0.3 insightface==0.7.3 mmcv==2.0.0 mmdet==3.0.0 -rembg==2.0.37 \ No newline at end of file +rembg==2.0.37 +facexlib==0.3.0 +basicsr==1.4.2 \ No newline at end of file diff --git a/scripts/download_models.py b/scripts/download_models.py index 084595a..a048855 100644 --- a/scripts/download_models.py +++ b/scripts/download_models.py @@ -26,6 +26,18 @@ ], "destination": "insightface", }, + "GFPGAN (face enhancement)": { + "size": 332, + "download_url": [ + "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth", + # TODO: provide a way to selectively download models from "packs" + # https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth + # https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth + # https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth + # https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth + ], + "destination": "upscale_models", + }, } console = 
Console() diff --git a/utils.py b/utils.py index f8b6bc2..50fb5d7 100644 --- a/utils.py +++ b/utils.py @@ -4,23 +4,23 @@ from pathlib import Path import sys + def add_path(path, prepend=False): - if isinstance(path, list): for p in path: add_path(p, prepend) return - + if isinstance(path, Path): path = path.resolve().as_posix() - + if path not in sys.path: if prepend: sys.path.insert(0, path) else: sys.path.append(path) - - + + # Get the absolute path of the parent directory of the current script here = Path(__file__).parent.resolve() @@ -31,13 +31,18 @@ def add_path(path, prepend=False): font_path = here / "font.ttf" # Add extern folder to path -add_path(here / "extern") -add_path(here / "extern" / "SadTalker") +extern_root = here / "extern" +add_path(extern_root) +for pth in extern_root.iterdir(): + if pth.is_dir(): + add_path(pth) + # Add the ComfyUI directory and custom nodes path to the sys.path list add_path(comfy_dir) add_path((comfy_dir / "custom_nodes")) + # Tensor to PIL (grabbed from WAS Suite) def tensor2pil(image: torch.Tensor) -> Image.Image: return Image.fromarray( @@ -45,16 +50,30 @@ def tensor2pil(image: torch.Tensor) -> Image.Image: ) +# TODO: write pil2tensor counterpart (batch support) +# def tensor2pil(image: torch.Tensor) -> Union[Image.Image, List[Image.Image]]: +# batch_count = 1 +# if len(image.shape) > 3: +# batch_count = image.size(0) + +# if batch_count == 1: +# return Image.fromarray( +# np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) +# ) +# return [tensor2pil(image[i]) for i in range(batch_count)] + + # Convert PIL to Tensor (grabbed from WAS Suite) def pil2tensor(image: Image.Image) -> torch.Tensor: return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) -def img_np_to_tensor(img_np): - return torch.from_numpy(img_np / 255.0)[None,] -def img_tensor_to_np(img_tensor): - img_tensor = img_tensor.clone() - img_tensor = img_tensor * 255.0 - return img_tensor.squeeze(0).numpy().astype(np.float32) +def np2tensor(img_np): + return torch.from_numpy(img_np.astype(np.float32) / 255.0).unsqueeze(0) + + +def tensor2np(tensor: torch.Tensor) -> np.ndarray: + return np.clip(255.0 * tensor.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) + diff --git a/web/color_widget.js b/web/color_widget.js index fb41117..4bdca72 100644 --- a/web/color_widget.js +++ b/web/color_widget.js @@ -4,34 +4,34 @@ import { app } from "/scripts/app.js"; import { ComfyWidgets } from "/scripts/widgets.js"; export function CUSTOM_INT(node, inputName, val, func, config = {}) { - return { - widget: node.addWidget( - "number", - inputName, - val, - func, - Object.assign({}, { min: 0, max: 4096, step: 640, precision: 0 }, config) - ), - }; + return { + widget: node.addWidget( + "number", + inputName, + val, + func, + Object.assign({}, { min: 0, max: 4096, step: 640, precision: 0 }, config) + ), + }; } -const dumb_call = (v,d,node) => { - console.log("dumb_call", {v,d,node}); +const dumb_call = (v, d, node) => { + console.log("dumb_call", { v, d, node }); } -function isColorBright (rgb, threshold=240) { +function isColorBright(rgb, threshold = 240) { const brightess = getBrightness(rgb) - + return brightess > threshold - } - - function getBrightness (rgbObj) { - return Math.round(((parseInt(rgbObj[0]) * 299) + (parseInt(rgbObj[1]) * 587) + (parseInt(rgbObj[2]) * 114)) /1000) - } - +} + +function getBrightness(rgbObj) { + return Math.round(((parseInt(rgbObj[0]) * 299) + (parseInt(rgbObj[1]) * 587) + (parseInt(rgbObj[2]) * 114)) / 1000) +} + /** * @returns 
{import("/types/litegraph").IWidget} widget */ -const custom = (key,val) => { +const custom = (key, val, compute = false) => { /** @type {import("/types/litegraph").IWidget} */ const widget = {} // widget.y = 0; @@ -45,7 +45,7 @@ const custom = (key,val) => { widgetY, height) { const border = 3; - + // draw a rect with a border and a fill color ctx.fillStyle = "#000"; ctx.fillRect(0, widgetY, widgetWidth, height); @@ -68,8 +68,8 @@ const custom = (key,val) => { // ctx.strokeStyle = "#fff"; // ctx.strokeRect(border, widgetY + border, widgetWidth - border * 2, height - border * 2); - - + + // ctx.fillStyle = "#000"; // ctx.fillRect(widgetWidth/2 - border / 2 , widgetY + border / 2 , widgetWidth/2 + border / 2, height + border / 2); // ctx.fillStyle = this.value; @@ -78,118 +78,117 @@ const custom = (key,val) => { } widget.mouse = function (e, pos, node) { if (e.type === "pointerdown") { - console.log({e,pos,node}) - // get widgets of type type : "COLOR" - const widgets = node.widgets.filter(w => w.type === "COLOR"); - - for (const w of widgets) { - // color picker - const rect = [w.last_y, w.last_y + 32]; - console.log({rect,pos}) - if (pos[1] > rect[0] && pos[1] < rect[1]) { - console.log("color picker", node) - const picker = document.createElement("input"); - picker.type = "color"; - picker.value = this.value; - // picker.style.position = "absolute"; - // picker.style.left = ( pos[0]) + "px"; - // picker.style.top = ( pos[1]) + "px"; - - // place at screen center - // picker.style.position = "absolute"; - // picker.style.left = (window.innerWidth / 2) + "px"; - // picker.style.top = (window.innerHeight / 2) + "px"; - // picker.style.transform = "translate(-50%, -50%)"; - // picker.style.zIndex = 1000; - - - - document.body.appendChild(picker); - - picker.addEventListener("change", () => { - this.value = picker.value; - node.graph._version++; - node.setDirtyCanvas(true, true); - document.body.removeChild(picker); - }); - - // simulate click with screen center - const pointer_event = new MouseEvent('click', { - bubbles: false, - // cancelable: true, - pointerType: "mouse", - clientX: window.innerWidth / 2, - clientY: window.innerHeight / 2, - x: window.innerWidth / 2, - y: window.innerHeight / 2, - offsetX: window.innerWidth / 2, - offsetY: window.innerHeight / 2, - screenX: window.innerWidth / 2, - screenY: window.innerHeight / 2, - - - }); - console.log(e) - picker.dispatchEvent(pointer_event); - - }}}} + // get widgets of type type : "COLOR" + const widgets = node.widgets.filter(w => w.type === "COLOR"); + + for (const w of widgets) { + // color picker + const rect = [w.last_y, w.last_y + 32]; + if (pos[1] > rect[0] && pos[1] < rect[1]) { + console.log("color picker", node) + const picker = document.createElement("input"); + picker.type = "color"; + picker.value = this.value; + // picker.style.position = "absolute"; + // picker.style.left = ( pos[0]) + "px"; + // picker.style.top = ( pos[1]) + "px"; + + // place at screen center + // picker.style.position = "absolute"; + // picker.style.left = (window.innerWidth / 2) + "px"; + // picker.style.top = (window.innerHeight / 2) + "px"; + // picker.style.transform = "translate(-50%, -50%)"; + // picker.style.zIndex = 1000; + + + + document.body.appendChild(picker); + + picker.addEventListener("change", () => { + this.value = picker.value; + node.graph._version++; + node.setDirtyCanvas(true, true); + document.body.removeChild(picker); + }); + + // simulate click with screen center + const pointer_event = new MouseEvent('click', { + bubbles: 
false, + // cancelable: true, + pointerType: "mouse", + clientX: window.innerWidth / 2, + clientY: window.innerHeight / 2, + x: window.innerWidth / 2, + y: window.innerHeight / 2, + offsetX: window.innerWidth / 2, + offsetY: window.innerHeight / 2, + screenX: window.innerWidth / 2, + screenY: window.innerHeight / 2, + + + }); + console.log(e) + picker.dispatchEvent(pointer_event); + + } + } + } + } widget.computeSize = function (width) { return [width, 32]; } + return widget; } app.registerExtension({ name: "mtb.ColorPicker", - init: () => { - ComfyWidgets.COLOR = function () { - return { - widget:custom("color", "#ff0000") - }; - }; - }, + async beforeRegisterNodeDef(nodeType, nodeData, app) { - + //console.log("mtb.ColorPicker", { nodeType, nodeData, app }); const rinputs = nodeData.input?.required; // object with key/value pairs, "0" is the type // console.log(nodeData.name, { nodeType, nodeData, app }); if (!rinputs) return; - + let has_color = false; for (const [key, input] of Object.entries(rinputs)) { if (input[0] === "COLOR") { - has_color = true; + has_color = true; // input[1] = { default: "#ff0000" }; - - }} + + } + } if (!has_color) return; - + const onNodeCreated = nodeType.prototype.onNodeCreated; nodeType.prototype.onNodeCreated = function () { const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; this.serialize_widgets = true; - // if (rinputs[0] === "COLOR") { - // console.log(nodeData.name, { nodeType, nodeData, app }); + // if (rinputs[0] === "COLOR") { + // console.log(nodeData.name, { nodeType, nodeData, app }); - // loop through the inputs to find the color inputs - for (const [key, input] of Object.entries(rinputs)) { - if (input[0] === "COLOR") { - this.addCustomWidget(custom(key,input[1])) - } - // } - } - - this.onRemoved = function () { - // When removing this node we need to remove the input from the DOM - for (let y in this.widgets) { - if (this.widgets[y].canvas) { - this.widgets[y].canvas.remove(); + // loop through the inputs to find the color inputs + for (const [key, input] of Object.entries(rinputs)) { + if (input[0] === "COLOR") { + let widget = custom(key, input[1]) + + this.addCustomWidget(widget) } + // } } - }; + + this.onRemoved = function () { + // When removing this node we need to remove the input from the DOM + for (let y in this.widgets) { + if (this.widgets[y].canvas) { + this.widgets[y].canvas.remove(); + } + } + }; } } });