From b3a7ab11a2b1eded4acd37e82b8f7a158da00314 Mon Sep 17 00:00:00 2001 From: Tan Nian Wei Date: Sat, 11 Dec 2021 12:41:06 +0800 Subject: [PATCH] feat: added cutout augs to vqgan-clip --- app.py | 15 +++++++++++++++ logic.py | 27 +++++++++++++++++++++++++-- vqgan_utils.py | 16 ++++++++++++++-- 3 files changed, 54 insertions(+), 4 deletions(-) diff --git a/app.py b/app.py index 8e6ded7..8c1d803 100644 --- a/app.py +++ b/app.py @@ -44,6 +44,7 @@ def generate_image( mse_weight_decay: float = 0, mse_weight_decay_steps: int = 0, tv_loss_weight: float = 1e-3, + use_cutout_augmentations: bool = True, ) -> None: ### Init ------------------------------------------------------------------- @@ -61,6 +62,7 @@ def generate_image( mse_weight_decay=mse_weight_decay, mse_weight_decay_steps=mse_weight_decay_steps, tv_loss_weight=tv_loss_weight, + use_cutout_augmentations=use_cutout_augmentations, ) ### Load model ------------------------------------------------------------- @@ -189,6 +191,9 @@ def generate_image( "tv_loss_weight": tv_loss_weight, } + if use_cutout_augmentations: + details["use_cutout_augmentations"] = True + if "git" in sys.modules: try: repo = git.Repo(search_parent_directories=True) @@ -257,6 +262,9 @@ def generate_image( "tv_loss_weight": tv_loss_weight, } + if use_cutout_augmentations: + details["use_cutout_augmentations"] = True + if "git" in sys.modules: try: repo = git.Repo(search_parent_directories=True) @@ -453,6 +461,12 @@ def generate_image( else: tv_loss_weight = 0 + use_cutout_augmentations = st.sidebar.checkbox( + "Use cutout augmentations", + value=True, + help="Increases image quality, uses additional 1-2 GiB of GPU memory", + ) + submitted = st.form_submit_button("Run!") # End of form @@ -519,6 +533,7 @@ def generate_image( mse_weight=mse_weight, mse_weight_decay=mse_weight_decay, mse_weight_decay_steps=mse_weight_decay_steps, + use_cutout_augmentations=use_cutout_augmentations, ) vid_display_slot.video("temp.mp4") # debug_slot.write(st.session_state) # DEBUG diff --git a/logic.py b/logic.py index 562fee1..7854d66 100644 --- a/logic.py +++ b/logic.py @@ -18,6 +18,7 @@ from torch.nn import functional as F from torch import optim from torchvision import transforms +import kornia.augmentation as K class Run: @@ -70,7 +71,8 @@ def __init__( mse_weight=0.5, mse_weight_decay=0.1, mse_weight_decay_steps=50, - tv_loss_weight=1e-3 + tv_loss_weight=1e-3, + use_cutout_augmentations: bool = True # use_augs: bool = True, # noise_fac: float = 0.1, # use_noise: Optional[float] = None, @@ -130,6 +132,8 @@ def __init__( self.mse_weight_decay = mse_weight_decay self.mse_weight_decay_steps = mse_weight_decay_steps + self.use_cutout_augmentations = use_cutout_augmentations + # For TV loss self.tv_loss_weight = tv_loss_weight @@ -159,9 +163,28 @@ def model_init(self, init_image: Image.Image = None) -> None: cut_size = self.perceptor.visual.input_resolution e_dim = self.model.quantize.e_dim f = 2 ** (self.model.decoder.num_resolutions - 1) + + if self.use_cutout_augmentations: + noise_fac = 0.1 + augs = nn.Sequential( + K.RandomHorizontalFlip(p=0.5), + K.RandomSharpness(0.3, p=0.4), + K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"), + K.RandomPerspective(0.2, p=0.4), + K.ColorJitter(hue=0.01, saturation=0.01, p=0.7), + ) + else: + noise_fac = None + augs = None + self.make_cutouts = MakeCutouts( - cut_size, self.args.cutn, cut_pow=self.args.cut_pow + cut_size, + self.args.cutn, + cut_pow=self.args.cut_pow, + noise_fac=noise_fac, + augs=augs, ) + n_toks = self.model.quantize.n_e toksX, toksY = self.args.size[0] // f, self.args.size[1] // f sideX, sideY = toksX * f, toksY * f diff --git a/vqgan_utils.py b/vqgan_utils.py index e1b79a3..eb62d7b 100644 --- a/vqgan_utils.py +++ b/vqgan_utils.py @@ -143,11 +143,13 @@ def parse_prompt(prompt): class MakeCutouts(nn.Module): - def __init__(self, cut_size, cutn, cut_pow=1.0): + def __init__(self, cut_size, cutn, cut_pow=1.0, noise_fac=None, augs=None): super().__init__() self.cut_size = cut_size self.cutn = cutn self.cut_pow = cut_pow + self.noise_fac = noise_fac + self.augs = augs def forward(self, input): sideY, sideX = input.shape[2:4] @@ -162,7 +164,17 @@ def forward(self, input): offsety = torch.randint(0, sideY - size + 1, ()) cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size] cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) - return clamp_with_grad(torch.cat(cutouts, dim=0), 0, 1) + + if self.augs: + batch = self.augs(torch.cat(cutouts, dim=0)) + else: + batch = torch.cat(cutouts, dim=0) + + if self.noise_fac: + facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) + batch = batch + facs * torch.randn_like(batch) + + return clamp_with_grad(batch, 0, 1) def load_vqgan_model(config_path, checkpoint_path):