diff --git a/models/modules/attn_network.py b/models/modules/attn_network.py
index a882b22a1..5f6e4d052 100644
--- a/models/modules/attn_network.py
+++ b/models/modules/attn_network.py
@@ -1,4 +1,6 @@
 from torch import nn
+import torch.nn.functional as F
+import warnings
 
 
 class BaseGenerator_attn(nn.Module):
@@ -12,9 +14,32 @@
 
     def compute_outputs(self, input, attentions, images):
         outputs = []
         for i in range(self.nb_mask_attn - self.nb_mask_input):
-            outputs.append(images[i] * attentions[i])
+            if images[i].shape == attentions[i].shape:
+                outputs.append(images[i] * attentions[i])
+            else:
+                # resize attention maps that are spatially smaller than the image
+                warnings.warn("Bilinear interpolation of attention heads")
+                rattention = F.interpolate(
+                    attentions[i],
+                    size=(images[i].shape[2], images[i].shape[3]),
+                    mode="bilinear",
+                    align_corners=False,
+                )
+                outputs.append(images[i] * rattention)
+
         for i in range(self.nb_mask_attn - self.nb_mask_input, self.nb_mask_attn):
-            outputs.append(input * attentions[i])
+            if input.shape == attentions[i].shape:
+                outputs.append(input * attentions[i])
+            else:
+                # same fallback for the masks applied to the network input
+                warnings.warn("Bilinear interpolation of attention heads")
+                rattention = F.interpolate(
+                    attentions[i],
+                    size=(input.shape[2], input.shape[3]),
+                    mode="bilinear",
+                    align_corners=False,
+                )
+                outputs.append(input * rattention)
         return images, attentions, outputs
 
diff --git a/scripts/gen_sliding_images.py b/scripts/gen_sliding_images.py
deleted file mode 100644
index 5010812e1..000000000
--- a/scripts/gen_sliding_images.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import sys
-import os
-import glob
-
-sys.path.append("../")
-from models import networks
-from options.train_options import TrainOptions
-import cv2
-import torch
-from torchvision import transforms
-from torchvision.utils import save_image
-import numpy as np
-import argparse
-
-
-def sliding_window(image, stepSize, windowSize):
-    # slide a window across the image
-    for y in range(0, image.shape[0], stepSize):
-        for x in range(0, image.shape[1], stepSize):
-            # yield the current window
-            yield (x, y, image[y : y + windowSize[1], x : x + windowSize[0]])
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--model-in-file", help="file path to generator model (.pth file)", required=True
-)
-parser.add_argument(
-    "--model-type",
-    default="mobile_resnet_9blocks",
-    help="model type, e.g. mobile_resnet_9blocks",
mobile_resnet_9blocks", -) -parser.add_argument( - "--model-config", - help="optional model configuration, e.g /path/to/segformer_config_b0.py", -) -parser.add_argument( - "--padding-type", - type=str, - help="whether to use padding, zeros or reflect", - default="reflect", -) -parser.add_argument("--img-size", default=256, type=int, help="square image size") -parser.add_argument( - "--img-in", help="image or image folder to transform", required=True -) -parser.add_argument( - "--stepsize", - type=int, - default=128, - help="sliding window stepsize, to be set to image input size", -) -parser.add_argument("--windowsize", type=int, default=256, help="window input size") -parser.add_argument( - "--output-dir", help="full size pictures output directory", required=True -) -parser.add_argument("--bw", action="store_true", help="whether input/output is bw") -parser.add_argument("--cpu", action="store_true", help="whether to use CPU") -args = parser.parse_args() - -if args.bw: - input_nc = output_nc = 1 -else: - input_nc = output_nc = 3 - -# loading model -opt = TrainOptions().parse_json({}) -opt.data_crop_size = args.img_size -opt.data_load_size = args.img_size -opt.G_attn_nb_mask_attn = 10 -opt.G_attn_nb_mask_input = 1 -opt.G_netG = args.model_type -opt.G_padding_type = args.padding_type -opt.model_input_nc = input_nc -opt.model_output_nc = output_nc -if "segformer" in args.model_type: - opt.G_config_segformer = ( - args.model_config - ) # e.g. '/path/to/models/configs/segformer/segformer_config_b0.py' -opt.jg_dir = os.path.join("/".join(__file__.split("/")[:-2])) -model = networks.define_G(**vars(opt)) - -model.eval() -model.load_state_dict(torch.load(args.model_in_file)) - -if not args.cpu: - model = model.cuda() - -if os.path.isfile(args.img_in): - images = [args.img_in] -else: - images = glob.glob(args.img_in + "*.*") - -# preprocessing transforms -tranlist = [ - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), -] -tran = transforms.Compose(tranlist) - -for image in images: - - # reading image - img = cv2.imread(args.img_in) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - print(image, " / shape=", img.shape) - - outputmap = np.zeros((img.shape[0], img.shape[1], 3), np.uint8) - - # - walk through sliding windows - i = 0 - for (x, y, window) in sliding_window( - img, stepSize=args.stepsize, windowSize=(args.windowsize, args.windowsize) - ): - - # - if window is smaller than input sizes, fill it up correctly - windowtmp = window.copy() - resized = False - if window.shape[0] != args.stepsize or window.shape[1] != args.stepsize: - resized = True - windowfull = np.zeros((args.windowsize, args.windowsize, 3), np.uint8) - windowfull[0 : window.shape[0], 0 : window.shape[1]] = window.copy() - window = windowfull - - # - get the local image window - # windowpath = '/tmp/img'+str(i)+'.png' - # cv2.imwrite(windowpath, window) - - # preprocessing - img_tensor = tran(window) - if not args.cpu: - img_tensor = img_tensor.cuda() - - # run through model - out_tensor = model(img_tensor.unsqueeze(0))[0].detach() - - # post-processing - out_img = out_tensor.data.cpu().float().numpy() - print(out_img.shape) - out_img = (np.transpose(out_img, (1, 2, 0)) + 1) / 2.0 * 255.0 - # print(out_img) - out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR) - - # - combine the output images - if resized: - out_img = out_img[0 : windowtmp.shape[0], 0 : windowtmp.shape[1]] - outputmap[y : y + window.shape[1], x : x + window.shape[0]] = out_img - - i += 1 - - # save the full size output image - 
-    imgoutpath = (
-        args.output_dir + "/" + os.path.basename(image).replace(".jpg", "") + "_gan.jpg"
-    )
-    cv2.imwrite(imgoutpath, outputmap)
-    print("Successfully generated image ", imgoutpath)
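
A quick standalone sanity check of the interpolation fallback added to compute_outputs above (not part of the patch; the tensor sizes below are arbitrary illustrations):

import torch
import torch.nn.functional as F

image = torch.randn(1, 3, 8, 8)     # N, C, H, W
attention = torch.rand(1, 3, 4, 4)  # attention map at a smaller spatial size

# same resize-then-multiply path as the else branches in compute_outputs
rattention = F.interpolate(
    attention,
    size=(image.shape[2], image.shape[3]),
    mode="bilinear",
    align_corners=False,
)
assert rattention.shape == image.shape
masked = image * rattention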