Skip to content

Commit

Permalink
feat: added cutout augs to vqgan-clip
Browse files Browse the repository at this point in the history
  • Loading branch information
Tan Nian Wei committed Dec 11, 2021
1 parent 78703d1 commit b3a7ab1
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 4 deletions.
15 changes: 15 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def generate_image(
mse_weight_decay: float = 0,
mse_weight_decay_steps: int = 0,
tv_loss_weight: float = 1e-3,
use_cutout_augmentations: bool = True,
) -> None:

### Init -------------------------------------------------------------------
Expand All @@ -61,6 +62,7 @@ def generate_image(
mse_weight_decay=mse_weight_decay,
mse_weight_decay_steps=mse_weight_decay_steps,
tv_loss_weight=tv_loss_weight,
use_cutout_augmentations=use_cutout_augmentations,
)

### Load model -------------------------------------------------------------
Expand Down Expand Up @@ -189,6 +191,9 @@ def generate_image(
"tv_loss_weight": tv_loss_weight,
}

if use_cutout_augmentations:
details["use_cutout_augmentations"] = True

if "git" in sys.modules:
try:
repo = git.Repo(search_parent_directories=True)
Expand Down Expand Up @@ -257,6 +262,9 @@ def generate_image(
"tv_loss_weight": tv_loss_weight,
}

if use_cutout_augmentations:
details["use_cutout_augmentations"] = True

if "git" in sys.modules:
try:
repo = git.Repo(search_parent_directories=True)
Expand Down Expand Up @@ -453,6 +461,12 @@ def generate_image(
else:
tv_loss_weight = 0

use_cutout_augmentations = st.sidebar.checkbox(
"Use cutout augmentations",
value=True,
help="Increases image quality, uses additional 1-2 GiB of GPU memory",
)

submitted = st.form_submit_button("Run!")
# End of form

Expand Down Expand Up @@ -519,6 +533,7 @@ def generate_image(
mse_weight=mse_weight,
mse_weight_decay=mse_weight_decay,
mse_weight_decay_steps=mse_weight_decay_steps,
use_cutout_augmentations=use_cutout_augmentations,
)
vid_display_slot.video("temp.mp4")
# debug_slot.write(st.session_state) # DEBUG
27 changes: 25 additions & 2 deletions logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torch.nn import functional as F
from torch import optim
from torchvision import transforms
import kornia.augmentation as K


class Run:
Expand Down Expand Up @@ -70,7 +71,8 @@ def __init__(
mse_weight=0.5,
mse_weight_decay=0.1,
mse_weight_decay_steps=50,
tv_loss_weight=1e-3
tv_loss_weight=1e-3,
use_cutout_augmentations: bool = True
# use_augs: bool = True,
# noise_fac: float = 0.1,
# use_noise: Optional[float] = None,
Expand Down Expand Up @@ -130,6 +132,8 @@ def __init__(
self.mse_weight_decay = mse_weight_decay
self.mse_weight_decay_steps = mse_weight_decay_steps

self.use_cutout_augmentations = use_cutout_augmentations

# For TV loss
self.tv_loss_weight = tv_loss_weight

Expand Down Expand Up @@ -159,9 +163,28 @@ def model_init(self, init_image: Image.Image = None) -> None:
cut_size = self.perceptor.visual.input_resolution
e_dim = self.model.quantize.e_dim
f = 2 ** (self.model.decoder.num_resolutions - 1)

if self.use_cutout_augmentations:
noise_fac = 0.1
augs = nn.Sequential(
K.RandomHorizontalFlip(p=0.5),
K.RandomSharpness(0.3, p=0.4),
K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
K.RandomPerspective(0.2, p=0.4),
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
)
else:
noise_fac = None
augs = None

self.make_cutouts = MakeCutouts(
cut_size, self.args.cutn, cut_pow=self.args.cut_pow
cut_size,
self.args.cutn,
cut_pow=self.args.cut_pow,
noise_fac=noise_fac,
augs=augs,
)

n_toks = self.model.quantize.n_e
toksX, toksY = self.args.size[0] // f, self.args.size[1] // f
sideX, sideY = toksX * f, toksY * f
Expand Down
16 changes: 14 additions & 2 deletions vqgan_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,13 @@ def parse_prompt(prompt):


class MakeCutouts(nn.Module):
def __init__(self, cut_size, cutn, cut_pow=1.0):
def __init__(self, cut_size, cutn, cut_pow=1.0, noise_fac=None, augs=None):
    """Configure a sampler that produces ``cutn`` square cutouts per pass.

    Args:
        cut_size: Side length (px) each cutout is resampled to before
            being handed to the perceptor (the caller passes the CLIP
            visual input resolution).
        cutn: Number of cutouts generated per ``forward`` call; also the
            batch dimension of the per-cutout noise factors.
        cut_pow: Exponent shaping the random cutout-size distribution.
        noise_fac: Upper bound for the uniform per-cutout noise scale;
            ``None``/``0`` disables the additive Gaussian noise step.
        augs: Optional augmentation module (e.g. an ``nn.Sequential`` of
            kornia ops) applied to the stacked cutout batch; ``None``
            disables augmentation.
    """
    super().__init__()
    # Store configuration; all work happens lazily in forward().
    self.cut_size, self.cutn, self.cut_pow = cut_size, cutn, cut_pow
    self.noise_fac = noise_fac
    self.augs = augs

def forward(self, input):
sideY, sideX = input.shape[2:4]
Expand All @@ -162,7 +164,17 @@ def forward(self, input):
offsety = torch.randint(0, sideY - size + 1, ())
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
return clamp_with_grad(torch.cat(cutouts, dim=0), 0, 1)

if self.augs:
batch = self.augs(torch.cat(cutouts, dim=0))
else:
batch = torch.cat(cutouts, dim=0)

if self.noise_fac:
facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
batch = batch + facs * torch.randn_like(batch)

return clamp_with_grad(batch, 0, 1)


def load_vqgan_model(config_path, checkpoint_path):
Expand Down

0 comments on commit b3a7ab1

Please sign in to comment.