Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore Face node #8

Merged
merged 12 commits into from
Jul 5, 2023
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "extern/SadTalker"]
path = extern/SadTalker
url = https://github.com/OpenTalker/SadTalker.git
[submodule "extern/GFPGAN"]
path = extern/GFPGAN
url = https://github.com/TencentARC/GFPGAN.git
18 changes: 10 additions & 8 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import traceback
from .log import log, blue_text, get_summary, get_label
from .log import log, blue_text, cyan_text, get_summary, get_label
from .utils import here
import importlib
import os

NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
NODE_CLASS_MAPPINGS_DEBUG = {}


Expand Down Expand Up @@ -64,16 +65,17 @@ def load_nodes():
nodes = load_nodes()
for node_class in nodes:
class_name = node_class.__name__
class_name = node_class.__name__
node_name = f"{get_label(class_name)} (mtb)"
NODE_CLASS_MAPPINGS[node_name] = node_class
NODE_CLASS_MAPPINGS_DEBUG[node_name] = node_class.__doc__

node_label = f"{get_label(class_name)} (mtb)"
NODE_CLASS_MAPPINGS[node_label] = node_class
NODE_DISPLAY_NAME_MAPPINGS[class_name] = node_label
NODE_CLASS_MAPPINGS_DEBUG[node_label] = node_class.__doc__
# TODO: I removed this, I find it more convenient to write without spaces, but it breaks every of my workflows
# TODO (cont): and until I find a way to automate the conversion, I'll leave it like this

log.debug(
log.info(
f"Loaded the following nodes:\n\t"
+ "\n\t".join(
f"{k}: {blue_text(get_summary(doc)) if doc else '-'}"
f"{cyan_text(k)}: {blue_text(get_summary(doc)) if doc else '-'}"
for k, doc in NODE_CLASS_MAPPINGS_DEBUG.items()
)
)
37 changes: 29 additions & 8 deletions log.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
import logging
import re
import os

base_log_level = logging.DEBUG if os.environ.get("MTB_DEBUG") else logging.INFO
print(f"Log level: {base_log_level}")


# Custom object that discards the output
class NullWriter:
def write(self, text):
pass


class Formatter(logging.Formatter):
grey = "\x1b[38;20m"
cyan = "\x1b[36;20m"
purple = "\x1b[35;20m"
yellow = "\x1b[33;20m"
red = "\x1b[31;20m"
bold_red = "\x1b[31;1m"
Expand All @@ -12,8 +24,8 @@ class Formatter(logging.Formatter):
format = "[%(name)s] | %(levelname)s -> %(message)s"

FORMATS = {
logging.DEBUG: grey + format + reset,
logging.INFO: grey + format + reset,
logging.DEBUG: purple + format + reset,
logging.INFO: cyan + format + reset,
logging.WARNING: yellow + format + reset,
logging.ERROR: red + format + reset,
logging.CRITICAL: bold_red + format + reset,
Expand All @@ -25,21 +37,26 @@ def format(self, record):
return formatter.format(record)


def mklog(name, level=logging.DEBUG):
def mklog(name, level=base_log_level):
logger = logging.getLogger(name)
logger.setLevel(level)
# create console handler with a higher log level
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)

ch.setFormatter(Formatter())
for handler in logger.handlers:
logger.removeHandler(handler)

ch = logging.StreamHandler()
ch.setLevel(level)
ch.setFormatter(Formatter())
logger.addHandler(ch)

# Disable log propagation
logger.propagate = False

return logger


# - The main app logger
log = mklog(__package__)
log = mklog(__package__, base_log_level)


def log_user(arg):
Expand All @@ -54,6 +71,10 @@ def blue_text(text):
return f"\033[94m{text}\033[0m"


def cyan_text(text):
return f"\033[96m{text}\033[0m"


def get_label(label):
words = re.findall(r"(?:^|[A-Z])[a-z]*", label)
return " ".join(words).strip()
235 changes: 235 additions & 0 deletions nodes/faceenhance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
import logging
from gfpgan import GFPGANer
import cv2
import numpy as np
import os
from pathlib import Path
import folder_paths
from basicsr.utils import imwrite
from PIL import Image
from ..utils import pil2tensor, tensor2pil, np2tensor, tensor2np
import torch
from munch import Munch
from ..log import NullWriter, log
from comfy import model_management
import comfy
from typing import Tuple


class LoadFaceEnhanceModel:
def __init__(self) -> None:
pass

@classmethod
def get_models_root(cls):
return Path(folder_paths.models_dir) / "upscale_models"

@classmethod
def get_models(cls):
models_path = cls.get_models_root()

return [
x
for x in models_path.iterdir()
if x.name.endswith(".pth")
and ("GFPGAN" in x.name or "RestoreFormer" in x.name)
]

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model_name": (
[x.name for x in cls.get_models()],
{"default": "None"},
),
"upscale": ("INT", {"default": 2}),
},
"optional": {"bg_upsampler": ("UPSCALE_MODEL", {"default": None})},
}

RETURN_TYPES = ("FACEENHANCE_MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "load_model"
CATEGORY = "face"

def load_model(self, model_name, upscale=2, bg_upsampler=None):
basic = "RestoreFormer" not in model_name

root = self.get_models_root()

if bg_upsampler is not None:
log.warning(
f"Upscale value overridden to {bg_upsampler.scale} from bg_upsampler"
)
upscale = bg_upsampler.scale
bg_upsampler = BGUpscaleWrapper(bg_upsampler)

sys.stdout = NullWriter()
model = GFPGANer(
model_path=(root / model_name).as_posix(),
upscale=upscale,
arch="clean" if basic else "RestoreFormer", # or original for v1.0 only
channel_multiplier=2, # 1 for v1.0 only
bg_upsampler=bg_upsampler,
)

sys.stdout = sys.__stdout__
return (model,)


class BGUpscaleWrapper:
def __init__(self, upscale_model) -> None:
self.upscale_model = upscale_model

def enhance(self, img: Image.Image, outscale=2):
device = model_management.get_torch_device()
self.upscale_model.to(device)

tile = 128 + 64
overlap = 8

imgt = np2tensor(img)
imgt = imgt.movedim(-1, -3).to(device)

steps = imgt.shape[0] * comfy.utils.get_tiled_scale_steps(
imgt.shape[3], imgt.shape[2], tile_x=tile, tile_y=tile, overlap=overlap
)

log.debug(f"Steps: {steps}")

pbar = comfy.utils.ProgressBar(steps)

s = comfy.utils.tiled_scale(
imgt,
lambda a: self.upscale_model(a),
tile_x=tile,
tile_y=tile,
overlap=overlap,
upscale_amount=self.upscale_model.scale,
pbar=pbar,
)

self.upscale_model.cpu()
s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
return (tensor2np(s),)


import sys


class RestoreFace:
def __init__(self) -> None:
pass

RETURN_TYPES = ("IMAGE",)
FUNCTION = "restore"
CATEGORY = "face"

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"model": ("FACEENHANCE_MODEL",),
# Input are aligned faces
"aligned": (["true", "false"], {"default": "false"}),
# Only restore the center face
"only_center_face": (["true", "false"], {"default": "false"}),
# Adjustable weights
"weight": ("FLOAT", {"default": 0.5}),
"save_tmp_steps": (["true", "false"], {"default": "true"}),
}
}

def do_restore(
self,
image: torch.Tensor,
model: GFPGANer,
aligned,
only_center_face,
weight,
save_tmp_steps,
) -> torch.Tensor:
pimage = tensor2pil(image)
width, height = pimage.size

source_img = cv2.cvtColor(np.array(pimage), cv2.COLOR_RGB2BGR)

sys.stdout = NullWriter()
cropped_faces, restored_faces, restored_img = model.enhance(
source_img,
has_aligned=aligned,
only_center_face=only_center_face,
paste_back=True,
# TODO: weight has no effect in 1.3 and 1.4 (only tested these for now...)
weight=weight,
)
sys.stdout = sys.__stdout__
log.warning(f"Weight value has no effect for now. (value: {weight})")

if save_tmp_steps:
self.save_intermediate_images(cropped_faces, restored_faces, height, width)
output = None
if restored_img is not None:
output = Image.fromarray(cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB))
# imwrite(restored_img, save_restore_path)

return pil2tensor(output)

def restore(
self,
image: torch.Tensor,
model: GFPGANer,
aligned="false",
only_center_face="false",
weight=0.5,
save_tmp_steps="true",
) -> Tuple[torch.Tensor]:
save_tmp_steps = save_tmp_steps == "true"
aligned = aligned == "true"
only_center_face = only_center_face == "true"

out = [
self.do_restore(
image[i], model, aligned, only_center_face, weight, save_tmp_steps
)
for i in range(image.size(0))
]

return (torch.cat(out, dim=0),)

def get_step_image_path(self, step, idx):
(
full_output_folder,
filename,
counter,
_subfolder,
_filename_prefix,
) = folder_paths.get_save_image_path(
f"{step}_{idx:03}",
folder_paths.temp_directory,
)
file = f"{filename}_{counter:05}_.png"

return os.path.join(full_output_folder, file)

def save_intermediate_images(self, cropped_faces, restored_faces, height, width):
for idx, (cropped_face, restored_face) in enumerate(
zip(cropped_faces, restored_faces)
):
face_id = idx + 1
file = self.get_step_image_path("cropped_faces", face_id)
imwrite(cropped_face, file)

file = self.get_step_image_path("cropped_faces_restored", face_id)
imwrite(restored_face, file)

file = self.get_step_image_path("cropped_faces_compare", face_id)

# save comparison image
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
imwrite(cmp_img, file)


__nodes__ = [RestoreFace, LoadFaceEnhanceModel]
Loading