diff --git a/models/base_diffusion_model.py b/models/base_diffusion_model.py
new file mode 100644
index 000000000..0b8e03042
--- /dev/null
+++ b/models/base_diffusion_model.py
@@ -0,0 +1,146 @@
+import os
+import copy
+import torch
+from collections import OrderedDict
+from abc import abstractmethod
+from .modules.utils import get_scheduler
+from torchviz import make_dot
+from .base_model import BaseModel
+
+from util.network_group import NetworkGroup
+
+# for FID
+from data.base_dataset import get_transform
+from .modules.fid.pytorch_fid.fid_score import (
+    _compute_statistics_of_path,
+    calculate_frechet_distance,
+)
+from util.util import save_image, tensor2im
+import numpy as np
+from util.diff_aug import DiffAugment
+
+
+from inspect import isfunction
+
+
+import torch.nn.functional as F
+
+
+from tqdm import tqdm
+
+
+class BaseDiffusionModel(BaseModel):
+    """This class is an abstract base class (ABC) for models.
+    To create a subclass, you need to implement the following five functions:
+        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+        -- <set_input>: unpack data from dataset and apply preprocessing.
+        -- <forward>: produce intermediate results.
+        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
+        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+    """
+
+    def __init__(self, opt, rank):
+        """Initialize the BaseModel class.
+
+        Parameters:
+            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+        When creating your custom class, you need to implement your own initialization.
+        In this function, you should first call <BaseModel.__init__(self, opt)>.
+        Then, you need to define four lists:
+            -- self.loss_names (str list): specify the training losses that you want to plot and save.
+            -- self.model_names (str list): define networks used in our training.
+            -- self.visual_names (str list): specify the images that you want to display and save.
+            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+        """
+
+        super().__init__(opt, rank)
+
+        if hasattr(opt, "fs_light"):
+            self.fs_light = opt.fs_light
+
+        if opt.dataaug_diff_aug_policy != "":
+            self.diff_augment = DiffAugment(
+                opt.dataaug_diff_aug_policy, opt.dataaug_diff_aug_proba
+            )
+
+        self.objects_to_update = []
+
+        # Define loss functions
+        losses_G = ["G_tot"]
+
+        self.loss_names_G = losses_G
+
+        self.loss_functions_G = ["compute_G_loss_diffusion"]
+        self.forward_functions = ["forward_diffusion"]
+
+    def init_semantic_cls(self, opt):
+
+        # specify the training losses you want to print out.
+        # The training/test scripts will call <BaseModel.get_current_losses>
+
+        super().init_semantic_cls(opt)
+
+    def init_semantic_mask(self, opt):
+
+        # specify the training losses you want to print out.
+ # The training/test scripts will call + + super().init_semantic_mask(opt) + + def forward_diffusion(self): + """Run forward pass; called by both functions and .""" + self.real_A_pool.query(self.real_A) + self.real_B_pool.query(self.real_B) + + if self.opt.output_display_G_attention_masks: + images, attentions, outputs = self.netG_A.get_attention_masks(self.real_A) + for i, cur_mask in enumerate(attentions): + setattr(self, "attention_" + str(i), cur_mask) + + for i, cur_output in enumerate(outputs): + setattr(self, "output_" + str(i), cur_output) + + for i, cur_image in enumerate(images): + setattr(self, "image_" + str(i), cur_image) + + if self.opt.data_online_context_pixels > 0: + + bs = self.get_current_batch_size() + self.mask_context = torch.ones( + [ + bs, + self.opt.model_input_nc, + self.opt.data_crop_size + self.margin, + self.opt.data_crop_size + self.margin, + ], + device=self.device, + ) + + self.mask_context[ + :, + :, + self.opt.data_online_context_pixels : -self.opt.data_online_context_pixels, + self.opt.data_online_context_pixels : -self.opt.data_online_context_pixels, + ] = torch.zeros( + [ + bs, + self.opt.model_input_nc, + self.opt.data_crop_size, + self.opt.data_crop_size, + ], + device=self.device, + ) + + self.mask_context_vis = torch.nn.functional.interpolate( + self.mask_context, size=self.real_A.shape[2:] + )[:, 0] + + if self.use_temporal: + self.compute_temporal_fake(objective_domain="B") + + if hasattr(self, "netG_B"): + self.compute_temporal_fake(objective_domain="A") + + def mse_loss(self, output, target): + return F.mse_loss(output, target) diff --git a/models/diffusion_networks.py b/models/diffusion_networks.py new file mode 100644 index 000000000..af42ea1ec --- /dev/null +++ b/models/diffusion_networks.py @@ -0,0 +1,62 @@ +from .modules.utils import get_norm_layer + + +from .modules.unet_generator_attn.diffusion_generator import DiffusionGenerator + + +def define_G( + model_input_nc, + model_output_nc, + G_netG, + G_nblocks, + data_crop_size, + G_norm, + G_unet_mha_n_timestep_train, + G_unet_mha_n_timestep_test, + G_ngf, + G_unet_mha_num_head_channels, + **unused_options +): + """Create a generator + + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + G_netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 + G_norm (str) -- the name of normalization layers used in the network: batch | instance | none + + Returns a generator + + Our current implementation provides two types of generators: + U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images) + The original U-Net paper: https://arxiv.org/abs/1505.04597 + + Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks) + Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations. + We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style). + + The generator has been initialized by . It uses RELU for non-linearity. + """ + net = None + norm_layer = get_norm_layer(norm_type=G_norm) + + if G_netG == "unet_mha": + net = DiffusionGenerator( + unet="unet_mha", + image_size=data_crop_size, + in_channel=model_input_nc * 2, + inner_channel=G_ngf, # e.g. 64 in palette repo + out_channel=model_output_nc, + res_blocks=G_nblocks, # 2 in palette repo + attn_res=[16], # e.g. 
+ channel_mults=(1, 2, 4, 8), # e.g. + num_head_channels=G_unet_mha_num_head_channels, # e.g. 32 in palette repo + tanh=False, + n_timestep_train=G_unet_mha_n_timestep_train, + n_timestep_test=G_unet_mha_n_timestep_test, + ) + return net + else: + raise NotImplementedError( + "Generator model name [%s] is not recognized" % G_netG + ) diff --git a/models/gan_networks.py b/models/gan_networks.py index da1fb8ce1..5e11ac6e8 100644 --- a/models/gan_networks.py +++ b/models/gan_networks.py @@ -64,7 +64,6 @@ def define_G( G_config_segformer, G_stylegan2_num_downsampling, G_backward_compatibility_twice_resnet_blocks, - G_unet_mha_inner_channel, G_unet_mha_num_head_channels, **unused_options ): @@ -216,13 +215,15 @@ def define_G( net = UNet_mha( image_size=data_crop_size, in_channel=model_input_nc, - inner_channel=G_unet_mha_inner_channel, # e.g. 64 in palette repo + inner_channel=G_ngf, # e.g. 64 in palette repo out_channel=model_output_nc, res_blocks=G_nblocks, # 2 in palette repo attn_res=[16], # e.g. channel_mults=(1, 2, 4, 8), # e.g. num_head_channels=G_unet_mha_num_head_channels, # e.g. 32 in palette repo tanh=True, + n_timestep_train=0, # unused + n_timestep_test=0, # unused ) return net else: diff --git a/models/modules/unet_generator_attn/diffusion_generator.py b/models/modules/unet_generator_attn/diffusion_generator.py new file mode 100644 index 000000000..36311f7b1 --- /dev/null +++ b/models/modules/unet_generator_attn/diffusion_generator.py @@ -0,0 +1,165 @@ +import math +import torch +from inspect import isfunction +from functools import partial +import numpy as np +from tqdm import tqdm +from .unet_generator_attn import UNet +from torch import nn + + +class DiffusionGenerator(nn.Module): + def __init__( + self, + unet, + image_size, + in_channel, + inner_channel, + out_channel, + res_blocks, + attn_res, + tanh, + n_timestep_train, + n_timestep_test, + dropout=0, + channel_mults=(1, 2, 4, 8), + conv_resample=True, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=True, + resblock_updown=True, + use_new_attention_order=False, + ): + + super().__init__() + + if unet == "unet_mha": + self.denoise_fn = UNet( + image_size=image_size, + in_channel=in_channel, + inner_channel=inner_channel, + out_channel=out_channel, + res_blocks=res_blocks, + attn_res=attn_res, + tanh=tanh, + n_timestep_train=n_timestep_train, + n_timestep_test=n_timestep_test, + dropout=dropout, + channel_mults=channel_mults, + conv_resample=conv_resample, + use_checkpoint=use_checkpoint, + use_fp16=use_fp16, + num_heads=num_heads, + num_head_channels=num_head_channels, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + resblock_updown=resblock_updown, + use_new_attention_order=use_new_attention_order, + ) + + # Init noise schedule + self.denoise_fn.set_new_noise_schedule(phase="train") + self.denoise_fn.set_new_noise_schedule(phase="test") + + def restoration(self, y_cond, y_t=None, y_0=None, mask=None, sample_num=8): + phase = "test" + + b, *_ = y_cond.shape + + assert ( + self.denoise_fn.num_timesteps_test > sample_num + ), "num_timesteps must greater than sample_num" + sample_inter = self.denoise_fn.num_timesteps_test // sample_num + + y_t = self.default(y_t, lambda: torch.randn_like(y_cond)) + ret_arr = y_t + for i in tqdm( + reversed(range(0, self.denoise_fn.num_timesteps_test)), + desc="sampling loop time step", + total=self.denoise_fn.num_timesteps_test, + ): + t = torch.full((b,), i, 
device=y_cond.device, dtype=torch.long) + y_t = self.p_sample(y_t, t, y_cond=y_cond, phase=phase) + if mask is not None: + y_t = y_0 * (1.0 - mask) + mask * y_t + if i % sample_inter == 0: + ret_arr = torch.cat([ret_arr, y_t], dim=0) + return y_t, ret_arr + + def exists(self, x): + return x is not None + + def default(self, val, d): + if self.exists(val): + return val + return d() if isfunction(d) else d + + def extract(self, a, t, x_shape=(1, 1, 1, 1)): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + def p_mean_variance(self, y_t, t, phase, clip_denoised: bool, y_cond=None): + noise_level = self.extract( + getattr(self.denoise_fn, "gammas_" + phase), t, x_shape=(1, 1) + ).to(y_t.device) + y_0_hat = self.denoise_fn.predict_start_from_noise( + y_t, + t=t, + noise=self.denoise_fn(torch.cat([y_cond, y_t], dim=1), noise_level), + phase=phase, + ) + + if clip_denoised: + y_0_hat.clamp_(-1.0, 1.0) + + model_mean, posterior_log_variance = self.denoise_fn.q_posterior( + y_0_hat=y_0_hat, y_t=y_t, t=t, phase=phase + ) + return model_mean, posterior_log_variance + + def q_sample(self, y_0, sample_gammas, noise=None): + noise = self.default(noise, lambda: torch.randn_like(y_0)) + return sample_gammas.sqrt() * y_0 + (1 - sample_gammas).sqrt() * noise + + def p_sample(self, y_t, t, phase, clip_denoised=True, y_cond=None): + model_mean, model_log_variance = self.p_mean_variance( + y_t=y_t, t=t, clip_denoised=clip_denoised, y_cond=y_cond, phase=phase + ) + noise = torch.randn_like(y_t) if any(t > 0) else torch.zeros_like(y_t) + return model_mean + noise * (0.5 * model_log_variance).exp() + + def forward(self, y_0, y_cond, mask, noise): + b, *_ = y_0.shape + t = torch.randint( + 1, self.denoise_fn.num_timesteps_train, (b,), device=y_0.device + ).long() + + gammas = self.denoise_fn.gammas_train + + gamma_t1 = self.extract(gammas, t - 1, x_shape=(1, 1)) + sqrt_gamma_t2 = self.extract(gammas, t, x_shape=(1, 1)) + sample_gammas = (sqrt_gamma_t2 - gamma_t1) * torch.rand( + (b, 1), device=y_0.device + ) + gamma_t1 + sample_gammas = sample_gammas.view(b, -1) + + noise = self.default(noise, lambda: torch.randn_like(y_0)) + y_noisy = self.q_sample( + y_0=y_0, sample_gammas=sample_gammas.view(-1, 1, 1, 1), noise=noise + ) + + if mask is not None: + noise_hat = self.denoise_fn( + torch.cat([y_cond, y_noisy * mask + (1.0 - mask) * y_0], dim=1), + sample_gammas, + ) + else: + noise_hat = self.denoise_fn( + torch.cat([y_cond, y_noisy], dim=1), sample_gammas + ) + + return noise, noise_hat diff --git a/models/modules/unet_generator_attn/unet_generator_attn.py b/models/modules/unet_generator_attn/unet_generator_attn.py index d155b00bb..b51f14241 100644 --- a/models/modules/unet_generator_attn/unet_generator_attn.py +++ b/models/modules/unet_generator_attn/unet_generator_attn.py @@ -5,6 +5,8 @@ import torch.nn as nn import torch.nn.functional as F +import numpy as np + from .unet_attn_utils import ( checkpoint, zero_module, @@ -13,6 +15,8 @@ gamma_embedding, ) +from functools import partial + class SiLU(nn.Module): def forward(self, x): @@ -356,6 +360,8 @@ def __init__( res_blocks, attn_res, tanh, + n_timestep_train, + n_timestep_test, dropout=0, channel_mults=(1, 2, 4, 8), conv_resample=True, @@ -533,17 +539,33 @@ def __init__( zero_module(nn.Conv2d(input_ch, out_channel, 3, padding=1)), ) + self.beta_schedule = { + "train": { + "schedule": "linear", + "n_timestep": n_timestep_train, + "linear_start": 1e-6, + "linear_end": 0.01, + }, + "test": { + "schedule": "linear", + 
"n_timestep": n_timestep_test, + "linear_start": 1e-4, + "linear_end": 0.09, + }, + } + def compute_feats(self, input, gammas): + if gammas is None: b = input.shape[0] gammas = torch.ones((b,)).to(input.device) - hs = [] - gammas = gammas.view( - -1, - ) + gammas = gammas.view(-1) + emb = self.cond_embed(gamma_embedding(gammas, self.inner_channel)) + hs = [] + h = input.type(torch.float32) for module in self.input_blocks: h = module(h, emb) @@ -551,24 +573,20 @@ def compute_feats(self, input, gammas): h = self.middle_block(h, emb) outs, feats = h, hs - return outs, feats + return outs, feats, emb def forward(self, input, gammas=None): - if gammas is None: - b = input.shape[0] - gammas = torch.ones((b,)).to(input.device) - h, hs = self.compute_feats(input, gammas) - emb = self.cond_embed(gamma_embedding(gammas, self.inner_channel)) - for module in self.output_blocks: + h, hs, emb = self.compute_feats(input, gammas=gammas) + + for i, module in enumerate(self.output_blocks): h = torch.cat([h, hs.pop()], dim=1) h = module(h, emb) h = h.type(input.dtype) return self.out(h) def get_feats(self, input, extract_layer_ids): - _, hs = self.compute_feats(input, gammas=None) - + _, hs, _ = self.compute_feats(input, gammas=None) feats = [] for i, feat in enumerate(hs): @@ -577,6 +595,105 @@ def get_feats(self, input, extract_layer_ids): return feats + def set_new_noise_schedule(self, phase): + to_torch = partial(torch.tensor, dtype=torch.float32) + betas = make_beta_schedule(**self.beta_schedule[phase]) + betas = ( + betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas + ) + alphas = 1.0 - betas + + (timesteps,) = betas.shape + setattr(self, "num_timesteps_" + phase, int(timesteps)) + + gammas = np.cumprod(alphas, axis=0) + gammas_prev = np.append(1.0, gammas[:-1]) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("gammas_" + phase, to_torch(gammas)) + name = "gammas_" + phase + self.register_buffer( + "sqrt_recip_gammas_" + phase, to_torch(np.sqrt(1.0 / gammas)) + ) + self.register_buffer( + "sqrt_recipm1_gammas_" + phase, to_torch(np.sqrt(1.0 / gammas - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1.0 - gammas_prev) / (1.0 - gammas) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped_" + phase, + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1_" + phase, + to_torch(betas * np.sqrt(gammas_prev) / (1.0 - gammas)), + ) + self.register_buffer( + "posterior_mean_coef2_" + phase, + to_torch((1.0 - gammas_prev) * np.sqrt(alphas) / (1.0 - gammas)), + ) + + def predict_start_from_noise(self, y_t, t, noise, phase): + return ( + self.extract(getattr(self, "sqrt_recip_gammas_" + phase), t, y_t.shape) + * y_t + - self.extract(getattr(self, "sqrt_recipm1_gammas_" + phase), t, y_t.shape) + * noise + ) + + def q_posterior(self, y_0_hat, y_t, t, phase): + posterior_mean = ( + self.extract(getattr(self, "posterior_mean_coef1_" + phase), t, y_t.shape) + * y_0_hat + + self.extract(getattr(self, "posterior_mean_coef2_" + phase), t, y_t.shape) + * y_t + ) + posterior_log_variance_clipped = self.extract( + getattr(self, "posterior_log_variance_clipped_" + phase), t, y_t.shape + ) + return posterior_mean, posterior_log_variance_clipped + + def extract(self, a, t, x_shape=(1, 1, 1, 1)): + b, *_ = t.shape + out = a.gather(-1, t) + return 
out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-6, linear_end=1e-2, cosine_s=8e-3 +): + if schedule == "quad": + betas = ( + np.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=np.float64 + ) + ** 2 + ) + elif schedule == "linear": + betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64) + elif schedule == "warmup10": + betas = _warmup_beta(linear_start, linear_end, n_timestep, 0.1) + elif schedule == "warmup50": + betas = _warmup_beta(linear_start, linear_end, n_timestep, 0.5) + elif schedule == "const": + betas = linear_end * np.ones(n_timestep, dtype=np.float64) + elif schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace(n_timestep, 1, n_timestep, dtype=np.float64) + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * math.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = betas.clamp(max=0.999) + else: + raise NotImplementedError(schedule) + return betas + if __name__ == "__main__": b, c, h, w = 3, 6, 64, 64 diff --git a/models/palette_model.py b/models/palette_model.py new file mode 100644 index 000000000..4dcb0d9c2 --- /dev/null +++ b/models/palette_model.py @@ -0,0 +1,134 @@ +import torch +import tqdm + +from .base_diffusion_model import BaseDiffusionModel +from util.network_group import NetworkGroup +from util.iter_calculator import IterCalculator +from . import diffusion_networks + +import copy + +import warnings + + +class PaletteModel(BaseDiffusionModel): + def __init__(self, opt, rank): + super().__init__(opt, rank) + + # Visuals + self.visual_names.append(["gt_image", "cond_image", "mask", "output"]) + + if opt.G_nblocks == 9: + warnings.warn( + f"G_nblocks default value {opt.G_nblocks} is too high for palette model, 2 will be used instead." 
+ ) + opt.G_nblocks = 2 + + # Define networks + self.netG_A = diffusion_networks.define_G(**vars(opt)) + + self.model_names = ["G_A"] + + self.model_names_export = ["G_A"] + + # Define optimizer + self.optimizer_G = opt.optim( + opt, + self.netG_A.parameters(), + lr=opt.train_G_lr, + betas=(opt.train_beta1, opt.train_beta2), + ) + + self.optimizers.append(self.optimizer_G) + + # Define loss functions + + losses_G = ["G_tot"] + + self.loss_names_G = losses_G + self.loss_names = self.loss_names_G + + # Make group + self.networks_groups = [] + + losses_backward = ["loss_G_tot"] + + self.group_G = NetworkGroup( + networks_to_optimize=["G_A"], + forward_functions=[], + backward_functions=["compute_palette_loss"], + loss_names_list=["loss_names_G"], + optimizer=["optimizer_G"], + loss_backward=losses_backward, + networks_to_ema=["G_A"], + ) + self.networks_groups.append(self.group_G) + + losses_G = [] + + self.loss_names_G += losses_G + + self.loss_names = self.loss_names_G.copy() + + # Itercalculator + if self.opt.train_iter_size > 1: + + self.iter_calculator = IterCalculator(self.loss_names) + for i, cur_loss in enumerate(self.loss_names): + self.loss_names[i] = cur_loss + "_avg" + setattr(self, "loss_" + self.loss_names[i], 0) + + # homemade + self.loss_fn = self.mse_loss + self.sample_num = 2 # temp + + def set_input(self, data): + """must use set_device in tensor""" + + self.cond_image = data["A"].to(self.device) + self.gt_image = data["B"].to(self.device) + self.mask = data["B_label_mask"].to(self.device) + self.batch_size = self.cond_image.shape[0] + + self.real_A = self.cond_image + self.real_B = self.gt_image + + def compute_palette_loss(self): + y_0 = self.gt_image + y_cond = self.cond_image + mask = self.mask + noise = None + + noise, noise_hat = self.netG_A(y_0, y_cond, mask, noise) + + if mask is not None: + loss = self.loss_fn(mask * noise, mask * noise_hat) + else: + loss = self.loss_fn(noise, noise_hat) + + self.loss_G_tot = loss + + def inference(self): + if hasattr(self.netG_A, "module"): + netG = self.netG_A.module + else: + netG = self.netG_A + if True or self.task in ["inpainting", "uncropping"]: + self.output, self.visuals = netG.restoration( + self.cond_image, + y_t=self.cond_image, + y_0=self.gt_image, + mask=self.mask, + sample_num=self.sample_num, + ) + else: + self.output, self.visuals = self.restoration( + self.cond_image, sample_num=self.sample_num + ) + + self.fake_B = self.visuals[-1:] + + def compute_visuals(self): + super().compute_visuals() + with torch.no_grad(): + self.inference() diff --git a/options/base_options.py b/options/base_options.py index 5be806112..7ece5cd8b 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -114,10 +114,7 @@ def initialize(self, parser): "--model_type", type=str, default="cut", - choices=[ - "cut", - "cycle_gan", - ], + choices=["cut", "cycle_gan", "palette"], help="chooses which model to use.", ) parser.add_argument( @@ -248,8 +245,20 @@ def initialize(self, parser): help="specify multimodal latent vector encoder", ) - parser.add_argument("--G_unet_mha_inner_channel", default=64, type=int) parser.add_argument("--G_unet_mha_num_head_channels", default=32, type=int) + parser.add_argument( + "--G_unet_mha_n_timestep_train", + type=int, + default=2000, + help="Number of timesteps used for UNET mha training.", + ) + + parser.add_argument( + "--G_unet_mha_n_timestep_test", + type=int, + default=2000, + help="Number of timesteps used for UNET mha inference (test time).", + ) # discriminator parser.add_argument( diff 
--git a/scripts/gen_single_image.py b/scripts/gen_single_image.py index b3fcd5fc7..bc14311db 100644 --- a/scripts/gen_single_image.py +++ b/scripts/gen_single_image.py @@ -3,7 +3,7 @@ import json sys.path.append("../") -from models import networks +from models import networks, networks_diffusion from options.train_options import TrainOptions import cv2 import torch @@ -30,7 +30,7 @@ def load_model(modelpath, model_in_file, device): opt.model_input_nc += opt.train_mm_nz opt.jg_dir = "../" - model = networks.define_G(**vars(opt)) + model = networks_diffusion.define_G(**vars(opt)) model.eval() model.load_state_dict(torch.load(modelpath + "/" + model_in_file)) diff --git a/scripts/gen_single_image_diffusion.py b/scripts/gen_single_image_diffusion.py new file mode 100644 index 000000000..f8fa91b78 --- /dev/null +++ b/scripts/gen_single_image_diffusion.py @@ -0,0 +1,159 @@ +import sys +import os +import json + +sys.path.append("../") +from models import diffusion_networks +from options.train_options import TrainOptions +import cv2 +import torch +from torchvision import transforms +from torchvision.utils import save_image +import numpy as np +import argparse +from data.online_creation import fill_mask_with_random, fill_mask_with_color + + +def load_model(modelpath, model_in_file, device): + train_json_path = modelpath + "/train_config.json" + with open(train_json_path, "r") as jsonf: + train_json = json.load(jsonf) + opt = TrainOptions().parse_json(train_json) + if opt.model_multimodal: + opt.model_input_nc += opt.train_mm_nz + opt.jg_dir = "../" + + model = diffusion_networks.define_G(**vars(opt)) + model.eval() + model.load_state_dict(torch.load(modelpath + "/" + model_in_file)) + + model = model.to(device) + return model, opt + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--model-in-file", help="file path to generator model (.pth file)", required=True +) + +parser.add_argument("--img-size", default=256, type=int, help="square image size") +parser.add_argument("--img-in", help="image to transform", required=True) +parser.add_argument( + "--mask-in", help="mask used for image transformation", required=True +) +parser.add_argument("--img-out", help="transformed image", required=True) +parser.add_argument("--cpu", action="store_true", help="whether to use CPU") +parser.add_argument("--gpuid", type=int, default=0, help="which GPU to use") +args = parser.parse_args() + +# loading model +modelpath = args.model_in_file.replace(os.path.basename(args.model_in_file), "") +print("modelpath=", modelpath) + +if not args.cpu: + device = torch.device("cuda:" + str(args.gpuid)) +model, opt = load_model(modelpath, os.path.basename(args.model_in_file), device) + +# reading image +img = cv2.imread(args.img_in) +img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) +mask = cv2.imread(args.mask_in, 0) + +# preprocessing +totensor = transforms.ToTensor() +resize = transforms.Resize(args.img_size) +tranlist = [ + totensor, + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + resize, +] +tran = transforms.Compose(tranlist) +img_tensor = tran(img).clone().detach() +mask = torch.from_numpy(np.array(mask, dtype=np.int64)).unsqueeze(0) +mask = resize(mask).clone().detach() + +if not args.cpu: + img_tensor = img_tensor.to(device).clone().detach() + mask = mask.to(device).clone().detach() + +if opt.data_online_creation_rand_mask_A: + cond_image = fill_mask_with_random( + img_tensor.clone().detach(), mask.clone().detach(), -1 + ) +elif opt.data_online_creation_color_mask_A: + cond_image = fill_mask_with_color( 
+ img_tensor.clone().detach(), mask.clone().detach(), {} + ) + +# run through model +cond_image, img_tensor, mask = ( + cond_image.unsqueeze(0).clone().detach(), + img_tensor.unsqueeze(0).clone().detach(), + mask.unsqueeze(0).clone().detach(), +) + + +with torch.no_grad(): + out_tensor, visu = model.restoration( + cond_image.clone().detach(), + y_t=cond_image.clone().detach(), + y_0=img_tensor.clone().detach(), + mask=mask.clone().detach(), + sample_num=2, + ) + + +print("outtensor", out_tensor.shape, "visu", visu.shape) + +temp = img_tensor - out_tensor +print(temp.mean(), temp.min(), temp.max()) +print(visu.shape) + +# out_tensor = visu[-1:] + +# post-processing +out_img = out_tensor.detach().data.cpu().float().numpy()[0] +img_np = img_tensor.detach().data.cpu().float().numpy()[0] +cond_image = cond_image.detach().data.cpu().float().numpy()[0] +# cond_image = torch.randn_like(cond_image) +visu = visu.detach().data.cpu().float().numpy() +visu1 = visu[1] +visu2 = visu[2] +visu0 = visu[0] + +temp = out_img - img_np +print("np", temp.mean(), temp.min(), temp.max()) + +out_img = (np.transpose(out_img, (1, 2, 0)) + 1) / 2.0 * 255.0 +img_np = (np.transpose(img_np, (1, 2, 0)) + 1) / 2.0 * 255.0 +cond_image = (np.transpose(cond_image, (1, 2, 0)) + 1) / 2.0 * 255.0 +visu0 = (np.transpose(visu0, (1, 2, 0)) + 1) / 2.0 * 255.0 +visu1 = (np.transpose(visu1, (1, 2, 0)) + 1) / 2.0 * 255.0 +visu2 = (np.transpose(visu2, (1, 2, 0)) + 1) / 2.0 * 255.0 +print(out_img) +print(img_np) + +temp = out_img - img_np +print("np", temp.mean(), temp.min(), temp.max()) + +out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR) +cv2.imwrite(args.img_out, out_img) + +img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) +cv2.imwrite("/data1/pnsuau/checkpoints/test_palette_4/img_np.jpg", img_np) + +cond_image = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR) +cv2.imwrite("/data1/pnsuau/checkpoints/test_palette_4/cond_image.jpg", cond_image) + + +visu0 = cv2.cvtColor(visu0, cv2.COLOR_RGB2BGR) +cv2.imwrite("/data1/pnsuau/checkpoints/test_palette_4/visu0.jpg", visu0) + +visu1 = cv2.cvtColor(visu1, cv2.COLOR_RGB2BGR) +cv2.imwrite("/data1/pnsuau/checkpoints/test_palette_4/visu1.jpg", visu1) + +visu2 = cv2.cvtColor(visu2, cv2.COLOR_RGB2BGR) +cv2.imwrite("/data1/pnsuau/checkpoints/test_palette_4/visu2.jpg", visu2) + + +print("Successfully generated image ", args.img_out) diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index 0a12a984e..f74405cb9 100644 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -32,9 +32,6 @@ rm $ZIP_FILE python3 -m pytest -p no:cacheprovider -s "${current_dir}/../tests/test_run_nosemantic.py" --dataroot "$TARGET_NOSEM_DIR" OUT=$? -echo "Deleting target dir $DIR" -rm -rf $DIR/* - if [ $OUT != 0 ]; then exit 1 fi @@ -60,8 +57,13 @@ fi python3 -m pytest -p no:cacheprovider -s "${current_dir}/../tests/test_run_semantic_self_supervised.py" --dataroot "$TARGET_MASK_SEM_DIR" OUT=$? -echo "Deleting target dir $DIR" -rm -rf $DIR/* +if [ $OUT != 0 ]; then + exit 1 +fi + +###### diffusion process test +python3 -m pytest -p no:cacheprovider -s "${current_dir}/../tests/test_run_diffusion.py" --dataroot "$TARGET_MASK_SEM_DIR" +OUT=$? 
if [ $OUT != 0 ]; then exit 1 diff --git a/tests/test_run_diffusion.py b/tests/test_run_diffusion.py new file mode 100644 index 000000000..1ea1d80fa --- /dev/null +++ b/tests/test_run_diffusion.py @@ -0,0 +1,51 @@ +import pytest +import torch.multiprocessing as mp +import sys + +sys.path.append(sys.path[0] + "/..") +import train +from options.train_options import TrainOptions +from data import create_dataset + +json_like_dict = { + "name": "joligan_utest", + "output_display_env": "joligan_utest", + "output_display_id": 0, + "gpu_ids": "0", + "data_dataset_mode": "self_supervised_labeled_mask", + "data_load_size": 128, + "data_crop_size": 128, + "train_n_epochs": 1, + "train_n_epochs_decay": 0, + "data_max_dataset_size": 10, + "data_relative_paths": True, + "train_G_ema": True, + "dataaug_no_rotate": True, + "G_unet_mha_inner_channel": 32, + "G_unet_mha_num_head_channels": 16, + "G_nblocks": 1, + "G_padding_type": "reflect", + "data_online_creation_rand_mask_A": True, +} + + +models_diffusion = ["palette"] + +G_netG = ["unet_mha"] + +D_proj_network_type = ["efficientnet", "vitsmall"] + +f_s_net = ["unet", "segformer"] + + +def test_semantic_mask(dataroot): + json_like_dict["dataroot"] = dataroot + json_like_dict["checkpoints_dir"] = "/".join(dataroot.split("/")[:-1]) + for model in models_diffusion: + json_like_dict["model_type"] = model + json_like_dict["name"] += "_" + model + json_like_dict_c = json_like_dict.copy() + for Gtype in G_netG: + json_like_dict_c["G_netG"] = Gtype + opt = TrainOptions().parse_json(json_like_dict_c) + train.launch_training(opt)
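Note on the noise schedule added in unet_generator_attn.py: set_new_noise_schedule() precomputes, per phase, the cumulative alpha products ("gammas") and the posterior coefficients used at sampling time. Below is a minimal standalone sketch of that computation using the patch's "train" defaults; it is illustrative only, and the variable names simply mirror the buffers registered by the patch.

import numpy as np

n_timestep = 2000                      # --G_unet_mha_n_timestep_train default
linear_start, linear_end = 1e-6, 0.01  # "train" entry of beta_schedule in this patch

betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)
alphas = 1.0 - betas
gammas = np.cumprod(alphas, axis=0)        # cumulative alpha product (alpha_bar)
gammas_prev = np.append(1.0, gammas[:-1])

# posterior q(x_{t-1} | x_t, x_0) statistics, as registered as buffers
posterior_variance = betas * (1.0 - gammas_prev) / (1.0 - gammas)
posterior_log_variance_clipped = np.log(np.maximum(posterior_variance, 1e-20))
posterior_mean_coef1 = betas * np.sqrt(gammas_prev) / (1.0 - gammas)
posterior_mean_coef2 = (1.0 - gammas_prev) * np.sqrt(alphas) / (1.0 - gammas)

print(gammas[0], gammas[-1])  # close to 1 at t=0, decays towards 0 at t=T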
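For reference, the training objective wired up by PaletteModel.compute_palette_loss and DiffusionGenerator.forward is a masked noise-prediction MSE. Here is a self-contained toy sketch of that step; the real denoiser is the attention U-Net and also receives the sampled gamma level, so the plain convolution below is only a stand-in to keep the example runnable.

import torch
import torch.nn.functional as F

def q_sample(y_0, sample_gammas, noise):
    # forward diffusion: noisy sample at a continuous gamma level
    return sample_gammas.sqrt() * y_0 + (1 - sample_gammas).sqrt() * noise

b, c, h, w = 2, 3, 64, 64
y_0 = torch.randn(b, c, h, w)                      # ground-truth image
y_cond = torch.randn(b, c, h, w)                   # conditioning image
mask = torch.randint(0, 2, (b, 1, h, w)).float()   # region to inpaint
gammas = torch.cumprod(1 - torch.linspace(1e-6, 0.01, 2000), dim=0)

t = torch.randint(1, 2000, (b,))
gamma_t1, gamma_t2 = gammas[t - 1], gammas[t]
sample_gammas = ((gamma_t2 - gamma_t1) * torch.rand(b) + gamma_t1).view(b, 1, 1, 1)

noise = torch.randn_like(y_0)
y_noisy = q_sample(y_0, sample_gammas, noise)

denoise_fn = torch.nn.Conv2d(2 * c, c, 3, padding=1)  # stand-in for the U-Net
noise_hat = denoise_fn(torch.cat([y_cond, y_noisy * mask + (1 - mask) * y_0], dim=1))
loss_G_tot = F.mse_loss(mask * noise, mask * noise_hat)  # loss restricted to the mask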
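Finally, a sketch of a single reverse step, following p_mean_variance / p_sample in diffusion_generator.py: the predicted noise is turned into an x0 estimate, which parameterizes the posterior mean and log-variance, from which one ancestral sample is drawn. The U-Net output is replaced by random noise purely to keep the example self-contained, and the schedule is shortened to 1000 steps for brevity.

import torch

betas = torch.linspace(1e-4, 0.09, 1000, dtype=torch.float64)  # "test" schedule of this patch, shortened
alphas = 1.0 - betas
gammas = torch.cumprod(alphas, dim=0)
gammas_prev = torch.cat([torch.ones(1, dtype=torch.float64), gammas[:-1]])

sqrt_recip_gammas = torch.sqrt(1.0 / gammas)
sqrt_recipm1_gammas = torch.sqrt(1.0 / gammas - 1)
post_log_var = torch.log(torch.clamp(betas * (1 - gammas_prev) / (1 - gammas), min=1e-20))
post_coef1 = betas * torch.sqrt(gammas_prev) / (1 - gammas)
post_coef2 = (1 - gammas_prev) * torch.sqrt(alphas) / (1 - gammas)

t = 500
y_t = torch.randn(1, 3, 64, 64, dtype=torch.float64)
noise_pred = torch.randn_like(y_t)  # stand-in for denoise_fn(cat([y_cond, y_t]), gamma)

# x0 estimate from the predicted noise, then posterior mean/variance, then one sample
y_0_hat = (sqrt_recip_gammas[t] * y_t - sqrt_recipm1_gammas[t] * noise_pred).clamp_(-1.0, 1.0)
mean = post_coef1[t] * y_0_hat + post_coef2[t] * y_t
y_prev = mean + torch.randn_like(y_t) * (0.5 * post_log_var[t]).exp()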