Drop basicsr dependency #14467

Merged · 4 commits · Dec 30, 2023

2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yaml
@@ -57,7 +57,7 @@ jobs:
          2>&1 | tee output.txt &
      - name: Run tests
        run: |
-          wait-for-it --service 127.0.0.1:7860 -t 600
+          wait-for-it --service 127.0.0.1:7860 -t 20
          python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
      - name: Kill test server
        if: always()
39 changes: 28 additions & 11 deletions modules/face_restoration_utils.py
@@ -17,6 +17,28 @@
logger = logging.getLogger(__name__)


+def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
+    """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
+    assert img.shape[2] == 3, "image must be RGB"
+    if img.dtype == "float64":
+        img = img.astype("float32")
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    return torch.from_numpy(img.transpose(2, 0, 1)).float()
+
+
+def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
+    """
+    Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
+    """
+    tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
+    assert tensor.dim() == 3, "tensor must be RGB"
+    img_np = tensor.numpy().transpose(1, 2, 0)
+    if img_np.shape[2] == 1:  # gray image, no RGB/BGR required
+        return np.squeeze(img_np, axis=2)
+    return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
+
+
def create_face_helper(device) -> FaceRestoreHelper:
    from facexlib.detection import retinaface
    from facexlib.utils.face_restoration_helper import FaceRestoreHelper
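
Note: the two helpers added above re-implement the narrow slice of `basicsr.utils.img2tensor` / `tensor2img` that this module actually used. A minimal round-trip sketch of how `restore_with_face_helper` uses them (the sample array and the fake model output are illustrative, not part of the PR):

```python
import numpy as np
import torch

from modules.face_restoration_utils import bgr_image_to_rgb_tensor, rgb_tensor_to_bgr_image

# A fake 4x4 BGR face crop in [0..255], standing in for a crop from FaceRestoreHelper.
cropped_face = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)

# Scale to [0..1] and convert to an RGB CHW float32 tensor.
face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
assert face_t.shape == (3, 4, 4) and face_t.dtype == torch.float32

# Pretend a network produced output in [-1, 1] (the CodeFormer/GFPGAN convention);
# min_max=(-1, 1) maps it back to a BGR HWC float image in [0..1].
fake_model_output = face_t.unsqueeze(0) * 2.0 - 1.0
restored = rgb_tensor_to_bgr_image(fake_model_output, min_max=(-1, 1))
assert restored.shape == (4, 4, 3)

# restore_with_face_helper then rescales to uint8 before pasting the face back.
restored_uint8 = (restored * 255.0).astype('uint8')
```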
@@ -36,14 +58,13 @@ def create_face_helper(device) -> FaceRestoreHelper:
def restore_with_face_helper(
    np_image: np.ndarray,
    face_helper: FaceRestoreHelper,
-    restore_face: Callable[[np.ndarray], np.ndarray],
+    restore_face: Callable[[torch.Tensor], torch.Tensor],
) -> np.ndarray:
    """
    Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.

    `restore_face` should take a cropped face image and return a restored face image.
    """
-    from basicsr.utils import img2tensor, tensor2img
    from torchvision.transforms.functional import normalize
    np_image = np_image[:, :, ::-1]
    original_resolution = np_image.shape[0:2]
@@ -56,23 +77,19 @@ def restore_with_face_helper(
        face_helper.align_warp_face()
        logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
        for cropped_face in face_helper.cropped_faces:
-            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+            cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)

            try:
                with torch.no_grad():
-                    restored_face = tensor2img(
-                        restore_face(cropped_face_t),
-                        rgb2bgr=True,
-                        min_max=(-1, 1),
-                    )
+                    cropped_face_t = restore_face(cropped_face_t)
                devices.torch_gc()
            except Exception:
                errors.report('Failed face-restoration inference', exc_info=True)
-                restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

-            restored_face = restored_face.astype('uint8')
+            restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
+            restored_face = (restored_face * 255.0).astype('uint8')
            face_helper.add_restored_face(restored_face)

        logger.debug("Merging restored faces into image")
@@ -126,7 +143,7 @@ def load_net(self) -> torch.Module:
    def restore_with_helper(
        self,
        np_image: np.ndarray,
-        restore_face: Callable[[np.ndarray], np.ndarray],
+        restore_face: Callable[[torch.Tensor], torch.Tensor],
    ) -> np.ndarray:
        try:
            if self.net is None:
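
With this change the `restore_face` callback works purely on tensors: it receives a 1x3xHxW RGB tensor, normalized to [-1, 1] and already moved to `devices.device_codeformer`, and must return a tensor in the same layout and range; all NumPy/BGR conversion now stays inside `restore_with_face_helper`. A sketch of a conforming callback, with `net` standing in for whatever face-restoration model a subclass loads (illustrative, not from the PR):

```python
import torch

net = torch.nn.Identity()  # stand-in for a real face-restoration network


def my_restore_face(cropped_face_t: torch.Tensor) -> torch.Tensor:
    # Receives a normalized 1x3xHxW RGB tensor and must return a tensor in the
    # same layout and [-1, 1] range; restore_with_face_helper already wraps the
    # call in torch.no_grad().
    return net(cropped_face_t)

# Used as: restorer.restore_with_helper(np_image, my_restore_face)
```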
10 changes: 7 additions & 3 deletions modules/textual_inversion/textual_inversion.py
@@ -11,7 +11,6 @@

import numpy as np
from PIL import Image, PngImagePlugin
-from torch.utils.tensorboard import SummaryWriter

from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
import modules.textual_inversion.dataset
@@ -344,6 +343,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
        })

def tensorboard_setup(log_directory):
+    from torch.utils.tensorboard import SummaryWriter
    os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
    return SummaryWriter(
        log_dir=os.path.join(log_directory, "tensorboard"),
@@ -448,8 +448,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    old_parallel_processing_allowed = shared.parallel_processing_allowed

+    tensorboard_writer = None
    if shared.opts.training_enable_tensorboard:
-        tensorboard_writer = tensorboard_setup(log_directory)
+        try:
+            tensorboard_writer = tensorboard_setup(log_directory)
+        except ImportError:
+            errors.report("Error initializing tensorboard", exc_info=True)

    pin_memory = shared.opts.pin_memory

@@ -622,7 +626,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"

-if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+if tensorboard_writer and shared.opts.training_tensorboard_save_images:
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)

if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
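
Taken together, the three hunks above make tensorboard optional: the module-level import is gone, `tensorboard_setup` imports `SummaryWriter` lazily, and training only logs images when a writer was actually created. A standalone sketch of the same lazy-import / graceful-degradation pattern (the flag and log directory names are illustrative):

```python
import logging

logger = logging.getLogger(__name__)
tensorboard_enabled = True  # stand-in for shared.opts.training_enable_tensorboard


def make_writer(log_dir):
    # Import only when the feature is actually used, so the package stays optional.
    from torch.utils.tensorboard import SummaryWriter
    return SummaryWriter(log_dir=log_dir)


writer = None
if tensorboard_enabled:
    try:
        writer = make_writer("logs/tensorboard")
    except ImportError:
        # Training keeps going; only tensorboard logging is lost.
        logger.exception("Error initializing tensorboard")

# Every later use is guarded on the writer itself, not on the option flag.
if writer:
    writer.add_scalar("loss", 0.123, global_step=1)
```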
1 change: 0 additions & 1 deletion requirements.txt
@@ -2,7 +2,6 @@ GitPython
Pillow
accelerate

-basicsr
blendmodes
clean-fid
einops
1 change: 0 additions & 1 deletion requirements_versions.txt
@@ -1,7 +1,6 @@
GitPython==3.1.32
Pillow==9.5.0
accelerate==0.21.0
-basicsr==1.4.2
blendmodes==2022
clean-fid==0.1.35
einops==0.4.1