From 7de1f6caab91f5a5fbfbf890cb6cfe28d93cf9c8 Mon Sep 17 00:00:00 2001 From: Tan Nian Wei Date: Tue, 12 Oct 2021 23:22:40 +0800 Subject: [PATCH 1/9] feat: minimal working example --- mse_vqgan_utils.py | 172 +++++++++++++++ mseapp.py | 419 +++++++++++++++++++++++++++++++++++++ mselogic.py | 508 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 1099 insertions(+) create mode 100644 mse_vqgan_utils.py create mode 100644 mseapp.py create mode 100644 mselogic.py diff --git a/mse_vqgan_utils.py b/mse_vqgan_utils.py new file mode 100644 index 0000000..f752154 --- /dev/null +++ b/mse_vqgan_utils.py @@ -0,0 +1,172 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +import io +from omegaconf import OmegaConf +from taming.models import cond_transformer, vqgan +from PIL import Image +from torchvision.transforms import functional as TF +import sys +import kornia.augmentation as K + +from vqgan_utils import resample, clamp_with_grad, replace_grad, vector_quantize + +sys.path.append("./taming-transformers") + + +def noise_gen(shape): + n, c, h, w = shape + noise = torch.zeros([n, c, 1, 1]) + for i in reversed(range(5)): + h_cur, w_cur = h // 2 ** i, w // 2 ** i + noise = F.interpolate( + noise, (h_cur, w_cur), mode="bicubic", align_corners=False + ) + noise += torch.randn([n, c, h_cur, w_cur]) / 5 + return noise + + +def one_sided_clip_loss(input, target, labels=None, logit_scale=100): + input_normed = F.normalize(input, dim=-1) + target_normed = F.normalize(target, dim=-1) + logits = input_normed @ target_normed.T * logit_scale + if labels is None: + labels = torch.arange(len(input), device=logits.device) + return F.cross_entropy(logits, labels) + + +class MSEMakeCutouts(nn.Module): + def __init__(self, cut_size, cutn, cut_pow=1.0, augs=None, noise_fac=0.1): + super().__init__() + self.cut_size = cut_size + self.cutn = cutn + self.cut_pow = cut_pow + self.augs = augs + self.noise_fac = noise_fac + + self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size)) + self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size)) + + def set_cut_pow(self, cut_pow): + self.cut_pow = cut_pow + + def forward(self, input): + sideY, sideX = input.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + + min_size_width = min(sideX, sideY) + + for ii in range(self.cutn): + + size = int( + torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size + ) + + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size] + cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) + + cutouts = torch.cat(cutouts, dim=0) + + if self.augs is not None: + cutouts = self.augs(cutouts) + + if self.noise_fac: + facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_( + 0, self.noise_fac + ) + cutouts = cutouts + facs * torch.randn_like(cutouts) + + return clamp_with_grad(cutouts, 0, 1) + + +class TVLoss(nn.Module): + def forward(self, input): + input = F.pad(input, (0, 1, 0, 1), "replicate") + x_diff = input[..., :-1, 1:] - input[..., :-1, :-1] + y_diff = input[..., 1:, :-1] - input[..., :-1, :-1] + diff = x_diff ** 2 + y_diff ** 2 + 1e-8 + return diff.mean(dim=1).sqrt().mean() + + +class GaussianBlur2d(nn.Module): + def __init__(self, sigma, window=0, mode="reflect", value=0): + super().__init__() + self.mode = mode + self.value = value + if not window: + window = max(math.ceil((sigma * 6 + 
1) / 2) * 2 - 1, 3) + if sigma: + kernel = torch.exp( + -((torch.arange(window) - window // 2) ** 2) / 2 / sigma ** 2 + ) + kernel /= kernel.sum() + else: + kernel = torch.ones([1]) + self.register_buffer("kernel", kernel) + + def forward(self, input): + n, c, h, w = input.shape + input = input.view([n * c, 1, h, w]) + start_pad = (self.kernel.shape[0] - 1) // 2 + end_pad = self.kernel.shape[0] // 2 + input = F.pad( + input, (start_pad, end_pad, start_pad, end_pad), self.mode, self.value + ) + input = F.conv2d(input, self.kernel[None, None, None, :]) + input = F.conv2d(input, self.kernel[None, None, :, None]) + return input.view([n, c, h, w]) + + +class EMATensor(nn.Module): + """implemented by Katherine Crowson""" + + def __init__(self, tensor, decay): + super().__init__() + self.tensor = nn.Parameter(tensor) + self.register_buffer("biased", torch.zeros_like(tensor)) + self.register_buffer("average", torch.zeros_like(tensor)) + self.decay = decay + self.register_buffer("accum", torch.tensor(1.0)) + self.update() + + @torch.no_grad() + def update(self): + if not self.training: + raise RuntimeError("update() should only be called during training") + + self.accum *= self.decay + self.biased.mul_(self.decay) + self.biased.add_((1 - self.decay) * self.tensor) + self.average.copy_(self.biased) + self.average.div_(1 - self.accum) + + def forward(self): + if self.training: + return self.tensor + return self.average + + +def synth_mse(model, z, is_openimages_f16_8192: bool = False, quantize: bool = True): + # z_mask not defined in notebook, appears to be unused + # if constraint_regions: + # z = replace_grad(z, z * z_mask) + + if quantize: + if is_openimages_f16_8192: + z_q = vector_quantize(z.movedim(1, 3), model.quantize.embed.weight).movedim( + 3, 1 + ) + else: + z_q = vector_quantize( + z.movedim(1, 3), model.quantize.embedding.weight + ).movedim(3, 1) + + else: + z_q = z + + return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1) diff --git a/mseapp.py b/mseapp.py new file mode 100644 index 0000000..6ce5eb7 --- /dev/null +++ b/mseapp.py @@ -0,0 +1,419 @@ +""" +This script is organized like so: ++ `if __name__ == "__main__"` sets up the Streamlit UI elements ++ `generate_image` houses interactions between UI and the CLIP image +generation models ++ Core model code is abstracted in `logic.py` and imported in `generate_image` +""" +import streamlit as st +from pathlib import Path +import sys +import datetime +import shutil +import json +import os +import base64 + +sys.path.append("./taming-transformers") + +from PIL import Image +from typing import Optional, List +from omegaconf import OmegaConf +import imageio +import numpy as np +from mselogic import MSEVQGANCLIPRun + + +def generate_image( + text_input: str = "the first day of the waters", + vqgan_ckpt: str = "vqgan_imagenet_f16_16384", + num_steps: int = 300, + image_x: int = 300, + image_y: int = 300, + init_image: Optional[Image.Image] = None, + image_prompts: List[Image.Image] = [], + continue_prev_run: bool = False, + seed: Optional[int] = None, +) -> None: + + ### Init ------------------------------------------------------------------- + run = MSEVQGANCLIPRun( + text_input=text_input, + vqgan_ckpt=vqgan_ckpt, + num_steps=num_steps, + image_x=image_x, + image_y=image_y, + seed=seed, + init_image=init_image, + image_prompts=image_prompts, + continue_prev_run=continue_prev_run, + use_augs=True, + noise_fac=0.1, + use_noise=None, + mse_withzeros=True, + mse_decay_rate=50, + mse_epoches=5, + ) + + ### Load model 
------------------------------------------------------------- + + if continue_prev_run is True: + run.load_model( + prev_model=st.session_state["model"], + prev_perceptor=st.session_state["perceptor"], + ) + prev_run_id = st.session_state["run_id"] + + else: + # Remove the cache first! CUDA out of memory + if "model" in st.session_state: + del st.session_state["model"] + + if "perceptor" in st.session_state: + del st.session_state["perceptor"] + + st.session_state["model"], st.session_state["perceptor"] = run.load_model() + prev_run_id = None + + # Generate random run ID + # Used to link runs linked w/ continue_prev_run + # ref: https://stackoverflow.com/a/42703382/13095028 + # Use URL and filesystem safe version since we're using this as a folder name + run_id = st.session_state["run_id"] = base64.urlsafe_b64encode( + os.urandom(6) + ).decode("ascii") + + run_start_dt = datetime.datetime.now() + + ### Model init ------------------------------------------------------------- + if continue_prev_run is True: + run.model_init(init_image=st.session_state["prev_im"]) + elif init_image is not None: + run.model_init(init_image=init_image) + else: + run.model_init() + + ### Iterate ---------------------------------------------------------------- + step_counter = 0 + frames = [] + + try: + # Try block catches st.script_runner.StopException, no need for a dedicated stop button + # Reason is st.form is meant to be self-contained either within sidebar, or in main body + # The way the form is implemented in this app splits the form across both regions + # This is intended to prevent the model settings from crowding the main body + # However, touching any button resets the app state, making it impossible to + # implement a stop button that can still dump output + # Thankfully there's a built-in stop button :) + while True: + # While loop to accommodate running predetermined steps or running indefinitely + status_text.text(f"Running step {step_counter}") + + _, im = run.iterate() + + if num_steps > 0: # skip when num_steps = -1 + step_progress_bar.progress((step_counter + 1) / num_steps) + else: + step_progress_bar.progress(100) + + # At every step, display and save image + im_display_slot.image(im, caption="Output image", output_format="PNG") + st.session_state["prev_im"] = im + + # ref: https://stackoverflow.com/a/33117447/13095028 + # im_byte_arr = io.BytesIO() + # im.save(im_byte_arr, format="JPEG") + # frames.append(im_byte_arr.getvalue()) # read() + frames.append(np.asarray(im)) + + step_counter += 1 + + if (step_counter == num_steps) and num_steps > 0: + break + + # Stitch into video using imageio + writer = imageio.get_writer("temp.mp4", fps=24) + for frame in frames: + writer.append_data(frame) + writer.close() + + # Save to output folder if run completed + runoutputdir = outputdir / ( + run_start_dt.strftime("%Y%m%dT%H%M%S") + "-" + run_id + ) + runoutputdir.mkdir() + + # Save final image + im.save(runoutputdir / "output.PNG", format="PNG") + + # Save init image + if init_image is not None: + init_image.save(runoutputdir / "init-image.JPEG", format="JPEG") + + # Save image prompts + for count, image_prompt in enumerate(image_prompts): + image_prompt.save( + runoutputdir / f"image-prompt-{count}.JPEG", format="JPEG" + ) + + # Save animation + shutil.copy("temp.mp4", runoutputdir / "anim.mp4") + + # Save metadata + with open(runoutputdir / "details.json", "w") as f: + json.dump( + { + "run_id": run_id, + "num_steps": step_counter, + "planned_num_steps": num_steps, + "text_input": text_input, +
"init_image": False if init_image is None else True, + "image_prompts": False if len(image_prompts) == 0 else True, + "continue_prev_run": continue_prev_run, + "prev_run_id": prev_run_id, + "seed": run.seed, + "Xdim": image_x, + "ydim": image_y, + "vqgan_ckpt": vqgan_ckpt, + "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"), + "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"), + }, + f, + indent=4, + ) + + status_text.text("Done!") # End of run + + except st.script_runner.StopException as e: + # Dump output to dashboard + print(f"Received Streamlit StopException") + status_text.text("Execution interruped, dumping outputs ...") + writer = imageio.get_writer("temp.mp4", fps=24) + for frame in frames: + writer.append_data(frame) + writer.close() + + # TODO: Make the following DRY + # Save to output folder if run completed + runoutputdir = outputdir / ( + run_start_dt.strftime("%Y%m%dT%H%M%S") + "-" + run_id + ) + runoutputdir.mkdir() + + # Save final image + im.save(runoutputdir / "output.PNG", format="PNG") + + # Save init image + if init_image is not None: + init_image.save(runoutputdir / "init-image.JPEG", format="JPEG") + + # Save image prompts + for count, image_prompt in enumerate(image_prompts): + image_prompt.save( + runoutputdir / f"image-prompt-{count}.JPEG", format="JPEG" + ) + + # Save animation + shutil.copy("temp.mp4", runoutputdir / "anim.mp4") + + # Save metadata + with open(runoutputdir / "details.json", "w") as f: + json.dump( + { + "run_id": run_id, + "num_steps": step_counter, + "planned_num_steps": num_steps, + "text_input": text_input, + "init_image": False if init_image is None else True, + "image_prompts": False if len(image_prompts) == 0 else True, + "continue_prev_run": continue_prev_run, + "prev_run_id": prev_run_id, + "seed": run.seed, + "Xdim": image_x, + "ydim": image_y, + "vqgan_ckpt": vqgan_ckpt, + "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"), + "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"), + }, + f, + indent=4, + ) + status_text.text("Done!") # End of run + + +if __name__ == "__main__": + defaults = OmegaConf.load("defaults.yaml") + outputdir = Path("output") + if not outputdir.exists(): + outputdir.mkdir() + + st.set_page_config(page_title="VQGAN-CLIP playground") + st.title("VQGAN-CLIP playground") + + # Determine what weights are available in `assets/` + weights_dir = Path("assets").resolve() + available_weight_ckpts = list(weights_dir.glob("*.ckpt")) + available_weight_configs = list(weights_dir.glob("*.yaml")) + available_weights = [ + i.stem + for i in available_weight_ckpts + if i.stem in [j.stem for j in available_weight_configs] + ] + + # Set vqgan_imagenet_f16_1024 as default if possible + if "vqgan_imagenet_f16_1024" in available_weights: + default_weight_index = available_weights.index("vqgan_imagenet_f16_1024") + else: + default_weight_index = 0 + + # Start of input form + with st.form("form-inputs"): + # Only element not in the sidebar, but in the form + text_input = st.text_input( + "Text prompt", + help="VQGAN-CLIP will generate an image that best fits the prompt", + ) + radio = st.sidebar.radio( + "Model weights", + available_weights, + index=default_weight_index, + help="Choose which weights to load, trained on different datasets. Make sure the weights and configs are downloaded to `assets/` as per the README!", + ) + num_steps = st.sidebar.number_input( + "Num steps", + value=defaults["num_steps"], + min_value=-1, + max_value=None, + step=1, + help="Specify -1 to run indefinitely. 
Use Streamlit's stop button in the top right corner to terminate execution. The exception is caught so the most recent output will be dumped to dashboard", + ) + + image_x = st.sidebar.number_input( + "Xdim", value=defaults["Xdim"], help="Width of output image, in pixels" + ) + image_y = st.sidebar.number_input( + "ydim", value=defaults["ydim"], help="Height of output image, in pixels" + ) + set_seed = st.sidebar.checkbox( + "Set seed", + value=defaults["set_seed"], + help="Check to set random seed for reproducibility. Will add option to specify seed", + ) + + seed_widget = st.sidebar.empty() + if set_seed is True: + # Use text_input as number_input relies on JS + # which can't natively handle large numbers + # torch.seed() generates int w/ 19 or 20 chars! + seed_str = seed_widget.text_input( + "Seed", value=str(defaults["seed"]), help="Random seed to use" + ) + try: + seed = int(seed_str) + except ValueError as e: + st.error("seed input needs to be int") + else: + seed = None + + use_custom_starting_image = st.sidebar.checkbox( + "Use starting image", + value=defaults["use_starting_image"], + help="Check to add a starting image to the network", + ) + + starting_image_widget = st.sidebar.empty() + if use_custom_starting_image is True: + init_image = starting_image_widget.file_uploader( + "Upload starting image", + type=["png", "jpeg", "jpg"], + accept_multiple_files=False, + help="Starting image for the network, will be resized to fit specified dimensions", + ) + # Convert from UploadedFile object to PIL Image + if init_image is not None: + init_image: Image.Image = Image.open(init_image).convert( + "RGB" + ) # just to be sure + else: + init_image = None + + use_image_prompts = st.sidebar.checkbox( + "Add image prompt(s)", + value=defaults["use_image_prompts"], + help="Check to add image prompt(s), conditions the network similar to the text prompt", + ) + + image_prompts_widget = st.sidebar.empty() + if use_image_prompts is True: + image_prompts = image_prompts_widget.file_uploader( + "Upload image prompt(s)", + type=["png", "jpeg", "jpg"], + accept_multiple_files=True, + help="Image prompt(s) for the network, will be resized to fit specified dimensions", + ) + # Convert from UploadedFile object to PIL Image + if len(image_prompts) != 0: + image_prompts = [Image.open(i).convert("RGB") for i in image_prompts] + else: + image_prompts = [] + + continue_prev_run = st.sidebar.checkbox( + "Continue previous run", + value=defaults["continue_prev_run"], + help="Use existing image and existing weights for the next run. If yes, ignores 'Use starting image'", + ) + submitted = st.form_submit_button("Run!") + # End of form + + status_text = st.empty() + status_text.text("Pending input prompt") + step_progress_bar = st.progress(0) + + im_display_slot = st.empty() + vid_display_slot = st.empty() + debug_slot = st.empty() + + if "prev_im" in st.session_state: + im_display_slot.image( + st.session_state["prev_im"], caption="Output image", output_format="PNG" + ) + + with st.beta_expander("Expand for README"): + with open("README.md", "r") as f: + # description = f.read() + # Preprocess links to redirect to github + # Thank you https://discuss.streamlit.io/u/asehmi, works like a charm! 
+ # ref: https://discuss.streamlit.io/t/image-in-markdown/13274/8 + readme_lines = f.readlines() + readme_buffer = [] + images = ["docs/ui.jpeg", "docs/four-seasons-20210808.png"] + for line in readme_lines: + readme_buffer.append(line) + for image in images: + if image in line: + st.markdown(" ".join(readme_buffer[:-1])) + st.image( + f"https://raw.githubusercontent.com/tnwei/vqgan-clip-app/main/{image}" + ) + readme_buffer.clear() + st.markdown(" ".join(readme_buffer)) + + # st.write(description) + + if submitted: + # debug_slot.write(st.session_state) # DEBUG + status_text.text("Loading weights ...") + generate_image( + # Inputs + text_input=text_input, + vqgan_ckpt=radio, + num_steps=num_steps, + image_x=int(image_x), + image_y=int(image_y), + seed=int(seed) if set_seed is True else None, + init_image=init_image, + image_prompts=image_prompts, + continue_prev_run=continue_prev_run, + ) + vid_display_slot.video("temp.mp4") + # debug_slot.write(st.session_state) # DEBUG diff --git a/mselogic.py b/mselogic.py new file mode 100644 index 0000000..844a583 --- /dev/null +++ b/mselogic.py @@ -0,0 +1,508 @@ +from typing import Optional, List, Tuple +from PIL import Image +import argparse +import clip +from vqgan_utils import ( + load_vqgan_model, + MakeCutouts, + parse_prompt, + resize_image, + Prompt, + synth, + checkin, +) +import torch +from torchvision.transforms import functional as TF +import torch.nn as nn +from torch.nn import functional as F +from torch import optim +from torchvision import transforms +from mse_vqgan_utils import synth_mse, MSEMakeCutouts, noise_gen +import kornia.augmentation as K + + +class Run: + """ + Subclass this to house your own implementation of CLIP-based image generation + models within the UI + """ + + def __init__(self): + """ + Set up the run's config here + """ + pass + + def load_model(self): + """ + Load models here. Separated this from __init__ to allow loading model state + from a previous run + """ + pass + + def model_init(self): + """ + Continue run setup, for items that require the models to be in=place. + Call once after load_model + """ + pass + + def iterate(self): + """ + Place iteration logic here. Outputs results for human consumption at + every step. 
+ """ + pass + + +class VQGANCLIPRun(Run): + def __init__( + # Inputs + self, + text_input: str = "the first day of the waters", + vqgan_ckpt: str = "vqgan_imagenet_f16_16384", + num_steps: int = 300, + image_x: int = 300, + image_y: int = 300, + init_image: Optional[Image.Image] = None, + image_prompts: List[Image.Image] = [], + continue_prev_run: bool = False, + seed: Optional[int] = None, + ## **kwargs, # Use this to receive Streamlit objects ## Call from main UI + ) -> None: + super().__init__() + self.text_input = text_input + self.vqgan_ckpt = vqgan_ckpt + self.num_steps = num_steps + self.image_x = image_x + self.image_y = image_y + self.init_image = init_image + self.image_prompts = image_prompts + self.continue_prev_run = continue_prev_run + self.seed = seed + + # Setup ------------------------------------------------------------------------------ + # Split text by "|" symbol + texts = [phrase.strip() for phrase in text_input.split("|")] + if texts == [""]: + texts = [] + + # Leaving most of this untouched + self.args = argparse.Namespace( + prompts=texts, + image_prompts=image_prompts, + noise_prompt_seeds=[], + noise_prompt_weights=[], + size=[int(image_x), int(image_y)], + init_image=init_image, + init_weight=0.0, + # clip.available_models() + # ['RN50', 'RN101', 'RN50x4', 'ViT-B/32'] + # Visual Transformer seems to be the smallest + clip_model="ViT-B/32", + vqgan_config=f"assets/{vqgan_ckpt}.yaml", + vqgan_checkpoint=f"assets/{vqgan_ckpt}.ckpt", + step_size=0.05, + cutn=64, + cut_pow=1.0, + display_freq=50, + seed=seed, + ) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.device = device + print("Using device:", device) + + def load_model( + self, prev_model: nn.Module = None, prev_perceptor: nn.Module = None + ) -> Optional[Tuple[nn.Module, nn.Module]]: + if self.continue_prev_run is True: + self.model = prev_model + self.perceptor = prev_perceptor + return None + + else: + self.model = load_vqgan_model( + self.args.vqgan_config, self.args.vqgan_checkpoint + ).to(self.device) + + self.perceptor = ( + clip.load(self.args.clip_model, jit=False)[0] + .eval() + .requires_grad_(False) + .to(self.device) + ) + + return self.model, self.perceptor + + def model_init(self, init_image: Image.Image = None) -> None: + cut_size = self.perceptor.visual.input_resolution + e_dim = self.model.quantize.e_dim + f = 2 ** (self.model.decoder.num_resolutions - 1) + self.make_cutouts = MakeCutouts( + cut_size, self.args.cutn, cut_pow=self.args.cut_pow + ) + n_toks = self.model.quantize.n_e + toksX, toksY = self.args.size[0] // f, self.args.size[1] // f + sideX, sideY = toksX * f, toksY * f + self.z_min = self.model.quantize.embedding.weight.min(dim=0).values[ + None, :, None, None + ] + self.z_max = self.model.quantize.embedding.weight.max(dim=0).values[ + None, :, None, None + ] + + if self.seed is not None: + torch.manual_seed(self.seed) + else: + self.seed = torch.seed() # Trigger a seed, retrieve the utilized seed + + # Initialization order: continue_prev_im, init_image, then only random init + if init_image is not None: + init_image = init_image.resize((sideX, sideY), Image.LANCZOS) + self.z, *_ = self.model.encode( + TF.to_tensor(init_image).to(self.device).unsqueeze(0) * 2 - 1 + ) + elif self.args.init_image: + pil_image = self.args.init_image + pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS) + self.z, *_ = self.model.encode( + TF.to_tensor(pil_image).to(self.device).unsqueeze(0) * 2 - 1 + ) + else: + one_hot = F.one_hot( + torch.randint(n_toks, 
[toksY * toksX], device=self.device), n_toks + ).float() + self.z = one_hot @ self.model.quantize.embedding.weight + self.z = self.z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2) + self.z_orig = self.z.clone() + self.z.requires_grad_(True) + self.opt = optim.Adam([self.z], lr=self.args.step_size) + + self.normalize = transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ) + + self.pMs = [] + + for prompt in self.args.prompts: + txt, weight, stop = parse_prompt(prompt) + embed = self.perceptor.encode_text( + clip.tokenize(txt).to(self.device) + ).float() + self.pMs.append(Prompt(embed, weight, stop).to(self.device)) + + for uploaded_image in self.args.image_prompts: + # path, weight, stop = parse_prompt(prompt) + # img = resize_image(Image.open(fetch(path)).convert("RGB"), (sideX, sideY)) + img = resize_image(uploaded_image.convert("RGB"), (sideX, sideY)) + batch = self.make_cutouts(TF.to_tensor(img).unsqueeze(0).to(self.device)) + embed = self.perceptor.encode_image(self.normalize(batch)).float() + self.pMs.append(Prompt(embed, weight, stop).to(self.device)) + + for seed, weight in zip( + self.args.noise_prompt_seeds, self.args.noise_prompt_weights + ): + gen = torch.Generator().manual_seed(seed) + embed = torch.empty([1, self.perceptor.visual.output_dim]).normal_( + generator=gen + ) + self.pMs.append(Prompt(embed, weight).to(self.device)) + + def _ascend_txt(self) -> List: + out = synth(self.model, self.z) + iii = self.perceptor.encode_image( + self.normalize(self.make_cutouts(out)) + ).float() + + result = [] + + if self.args.init_weight: + result.append(F.mse_loss(self.z, self.z_orig) * self.args.init_weight / 2) + + for prompt in self.pMs: + result.append(prompt(iii)) + + return result + + def iterate(self) -> Tuple[List[float], Image.Image]: + # Forward prop + self.opt.zero_grad() + losses = self._ascend_txt() + + # Grab an image + im: Image.Image = checkin(self.model, self.z) + + # Backprop + loss = sum(losses) + loss.backward() + self.opt.step() + with torch.no_grad(): + self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max)) + + # Output stuff useful for humans + return [loss.item() for loss in losses], im + + +class MSEVQGANCLIPRun(VQGANCLIPRun): + def __init__( + # Inputs + self, + text_input: str = "the first day of the waters", + vqgan_ckpt: str = "vqgan_imagenet_f16_16384", + num_steps: int = 300, + image_x: int = 300, + image_y: int = 300, + init_image: Optional[Image.Image] = None, + image_prompts: List[Image.Image] = [], + continue_prev_run: bool = False, + seed: Optional[int] = None, + # MSE VQGAN-CLIP options + use_augs: bool = True, + noise_fac: float = 0.1, + use_noise: Optional[float] = None, + mse_withzeros=True, + mse_decay_rate=50, + mse_epoches=5, + ) -> None: + super().__init__() + self.text_input = text_input + self.vqgan_ckpt = vqgan_ckpt + self.num_steps = num_steps + self.image_x = image_x + self.image_y = image_y + self.init_image = init_image + self.image_prompts = image_prompts + self.continue_prev_run = continue_prev_run + self.seed = seed + + # Setup ------------------------------------------------------------------------------ + # Split text by "|" symbol + texts = [phrase.strip() for phrase in text_input.split("|")] + if texts == [""]: + texts = [] + + # Leaving most of this untouched + self.args = argparse.Namespace( + prompts=texts, + image_prompts=image_prompts, + noise_prompt_seeds=[], + noise_prompt_weights=[], + size=[int(image_x), int(image_y)], + init_image=init_image, + 
init_weight=0.5, # differ from standard VQGAN CLIP + clip_model="ViT-B/32", + vqgan_config=f"assets/{vqgan_ckpt}.yaml", + vqgan_checkpoint=f"assets/{vqgan_ckpt}.ckpt", + step_size=0.05, + cutn=64, + cut_pow=1.0, + display_freq=50, + seed=seed, + ) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.device = device + print("Using device:", device) + + # TODO: MSE regularized options here + self.iterate_counter = 0 + self.use_augs = use_augs + self.noise_fac = noise_fac + self.use_noise = use_noise + self.mse_withzeros = mse_withzeros + self.mse_decay_rate = mse_decay_rate + self.mse_epoches = mse_epoches + + # Added init for MSE VQGAN + if self.args.init_weight: + self.mse_decay = self.args.init_weight / self.mse_epoches + + self.mse_weight = self.args.init_weight + + self.augs = nn.Sequential( + K.RandomHorizontalFlip(p=0.5), + K.RandomAffine( + degrees=30, translate=0.1, p=0.8, padding_mode="border" + ), # padding_mode=2 + K.RandomPerspective(0.2, p=0.4), + K.ColorJitter(hue=0.01, saturation=0.01, p=0.7), + ) + + def model_init(self, init_image: Image.Image = None) -> None: + cut_size = self.perceptor.visual.input_resolution + + if self.args.vqgan_checkpoint == "vqgan_openimages_f16_8192.ckpt": + self.e_dim = 256 + self.n_toks = self.model.quantize.n_embed + self.z_min = self.model.quantize.embed.weight.min(dim=0).values[ + None, :, None, None + ] + self.z_max = self.model.quantize.embed.weight.max(dim=0).values[ + None, :, None, None + ] + else: + self.e_dim = self.model.quantize.e_dim + self.n_toks = self.model.quantize.n_e + self.z_min = self.model.quantize.embedding.weight.min(dim=0).values[ + None, :, None, None + ] + self.z_max = self.model.quantize.embedding.weight.max(dim=0).values[ + None, :, None, None + ] + + f = 2 ** (self.model.decoder.num_resolutions - 1) + self.make_cutouts = MSEMakeCutouts( + cut_size, + self.args.cutn, + cut_pow=self.args.cut_pow, + augs=self.augs if self.use_augs is True else None, + ) + n_toks = self.model.quantize.n_e + toksX, toksY = self.args.size[0] // f, self.args.size[1] // f + sideX, sideY = toksX * f, toksY * f # notebook used 16 instead of f here + + if self.seed is not None: + torch.manual_seed(self.seed) + else: + self.seed = torch.seed() # Trigger a seed, retrieve the utilized seed + + # Initialization order: continue_prev_im, init_image, then only random init + if init_image is not None: + init_image = init_image.resize((sideX, sideY), Image.LANCZOS) + init_image = TF.to_tensor(init_image) + + if self.args.use_noise: + init_image = init_image + self.args.use_noise * torch.randn_like( + init_image + ) + + self.z, *_ = self.model.encode( + TF.to_tensor(init_image).to(self.device).unsqueeze(0) * 2 - 1 + ) + elif self.args.init_image: + pil_image = self.args.init_image + pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS) + pil_image = TF.to_tensor(pil_image) + + if self.args.use_noise: + pil_image = pil_image + self.args.use_noise * torch.randn_like( + pil_image + ) + + self.z, *_ = self.model.encode( + TF.to_tensor(pil_image).to(self.device).unsqueeze(0) * 2 - 1 + ) + else: + one_hot = F.one_hot( + torch.randint(n_toks, [toksY * toksX], device=self.device), n_toks + ).float() + + if self.args.vqgan_checkpoint == "vqgan_openimages_f16_8192.ckpt": + self.z = one_hot @ self.model.quantize.embed.weight + else: + self.z = one_hot @ self.model.quantize.embedding.weight + + self.z = self.z.view([-1, toksY, toksX, self.e_dim]).permute(0, 3, 1, 2) + + if self.mse_withzeros and not self.args.init_image: + 
self.z_orig = torch.zeros_like(self.z) + else: + self.z_orig = self.z.clone() + + self.z.requires_grad_(True) + self.opt = optim.Adam([self.z], lr=self.args.step_size) + + self.normalize = transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ) + + self.pMs = [] + + for prompt in self.args.prompts: + txt, weight, stop = parse_prompt(prompt) + embed = self.perceptor.encode_text( + clip.tokenize(txt).to(self.device) + ).float() + self.pMs.append(Prompt(embed, weight, stop).to(self.device)) + + for uploaded_image in self.args.image_prompts: + # path, weight, stop = parse_prompt(prompt) + # img = resize_image(Image.open(fetch(path)).convert("RGB"), (sideX, sideY)) + img = resize_image(uploaded_image.convert("RGB"), (sideX, sideY)) + batch = self.make_cutouts(TF.to_tensor(img).unsqueeze(0).to(self.device)) + embed = self.perceptor.encode_image(self.normalize(batch)).float() + self.pMs.append(Prompt(embed, weight, stop).to(self.device)) + + for seed, weight in zip( + self.args.noise_prompt_seeds, self.args.noise_prompt_weights + ): + gen = torch.Generator().manual_seed(seed) + embed = torch.empty([1, self.perceptor.visual.output_dim]).normal_( + generator=gen + ) + self.pMs.append(Prompt(embed, weight).to(self.device)) + + def _ascend_txt(self) -> List: + out = synth_mse( + self.model, + self.z, + quantize=True, + is_openimages_f16_8192=True + if self.args.vqgan_checkpoint == "vqgan_openimages_f16_8192.ckpt" + else False, + ) + + cutouts = self.make_cutouts(out) + iii = self.perceptor.encode_image(self.normalize(cutouts)).float() + + result = [] + + if self.args.init_weight: + result.append(F.mse_loss(self.z, self.z_orig) * self.mse_weight / 2) + # result.append(F.mse_loss(z, z_orig) * ((1/torch.tensor((i)*2 + 1))*mse_weight) / 2) + + with torch.no_grad(): + if ( + self.iterate_counter > 0 + and self.iterate_counter % self.mse_decay_rate == 0 + and self.iterate_counter <= self.mse_decay_rate * self.mse_epoches + ): + + if ( + self.mse_weight - self.mse_decay > 0 + and self.mse_weight - self.mse_decay >= self.mse_decay + ): + self.mse_weight = self.mse_weight - self.mse_decay + print(f"updated mse weight: {self.mse_weight}") + else: + self.mse_weight = 0 + print(f"updated mse weight: {self.mse_weight}") + + for prompt in self.pMs: + result.append(prompt(iii)) + + return result + + def iterate(self) -> Tuple[List[float], Image.Image]: + # Forward prop + self.opt.zero_grad() + losses = self._ascend_txt() + + # Grab an image + im: Image.Image = checkin(self.model, self.z) + + # Backprop + loss = sum(losses) + loss.backward() + self.opt.step() + # with torch.no_grad(): + # self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max)) + + # Advance iteration counter + self.iterate_counter += 1 + + # Output stuff useful for humans + return [loss.item() for loss in losses], im From 59c142251cf767b1c7a3d1b79906577c2c842750 Mon Sep 17 00:00:00 2001 From: Tan Nian Wei Date: Wed, 13 Oct 2021 00:45:54 +0800 Subject: [PATCH 2/9] feat: added widget for mse init weight --- defaults.yaml | 1 + mseapp.py | 10 +++ mselogic.py | 231 ++------------------------------------------------ 3 files changed, 16 insertions(+), 226 deletions(-) diff --git a/defaults.yaml b/defaults.yaml index 1ce66b6..793cd69 100644 --- a/defaults.yaml +++ b/defaults.yaml @@ -7,3 +7,4 @@ seed: 0 use_starting_image: false use_image_prompts: false continue_prev_run: false +mse_weight: 0.0 \ No newline at end of file diff --git a/mseapp.py b/mseapp.py index 6ce5eb7..7ebd48c 100644 --- 
a/mseapp.py +++ b/mseapp.py @@ -34,6 +34,7 @@ def generate_image( image_prompts: List[Image.Image] = [], continue_prev_run: bool = False, seed: Optional[int] = None, + mse_weight: float = 0, ) -> None: ### Init ------------------------------------------------------------------- @@ -53,6 +54,7 @@ def generate_image( mse_withzeros=True, mse_decay_rate=50, mse_epoches=5, + mse_weight=mse_weight, ) ### Load model ------------------------------------------------------------- @@ -362,6 +364,13 @@ def generate_image( value=defaults["continue_prev_run"], help="Use existing image and existing weights for the next run. If yes, ignores 'Use starting image'", ) + mse_weight = st.sidebar.number_input( + "MSE weight", + value=defaults["mse_weight"], + min_value=0.0, + step=0.05, + help="Set weights for MSE regularization", + ) submitted = st.form_submit_button("Run!") # End of form @@ -414,6 +423,7 @@ def generate_image( init_image=init_image, image_prompts=image_prompts, continue_prev_run=continue_prev_run, + mse_weight=mse_weight, ) vid_display_slot.video("temp.mp4") # debug_slot.write(st.session_state) # DEBUG diff --git a/mselogic.py b/mselogic.py index 844a583..9d2aec7 100644 --- a/mselogic.py +++ b/mselogic.py @@ -19,230 +19,7 @@ from torchvision import transforms from mse_vqgan_utils import synth_mse, MSEMakeCutouts, noise_gen import kornia.augmentation as K - - -class Run: - """ - Subclass this to house your own implementation of CLIP-based image generation - models within the UI - """ - - def __init__(self): - """ - Set up the run's config here - """ - pass - - def load_model(self): - """ - Load models here. Separated this from __init__ to allow loading model state - from a previous run - """ - pass - - def model_init(self): - """ - Continue run setup, for items that require the models to be in=place. - Call once after load_model - """ - pass - - def iterate(self): - """ - Place iteration logic here. Outputs results for human consumption at - every step. 
- """ - pass - - -class VQGANCLIPRun(Run): - def __init__( - # Inputs - self, - text_input: str = "the first day of the waters", - vqgan_ckpt: str = "vqgan_imagenet_f16_16384", - num_steps: int = 300, - image_x: int = 300, - image_y: int = 300, - init_image: Optional[Image.Image] = None, - image_prompts: List[Image.Image] = [], - continue_prev_run: bool = False, - seed: Optional[int] = None, - ## **kwargs, # Use this to receive Streamlit objects ## Call from main UI - ) -> None: - super().__init__() - self.text_input = text_input - self.vqgan_ckpt = vqgan_ckpt - self.num_steps = num_steps - self.image_x = image_x - self.image_y = image_y - self.init_image = init_image - self.image_prompts = image_prompts - self.continue_prev_run = continue_prev_run - self.seed = seed - - # Setup ------------------------------------------------------------------------------ - # Split text by "|" symbol - texts = [phrase.strip() for phrase in text_input.split("|")] - if texts == [""]: - texts = [] - - # Leaving most of this untouched - self.args = argparse.Namespace( - prompts=texts, - image_prompts=image_prompts, - noise_prompt_seeds=[], - noise_prompt_weights=[], - size=[int(image_x), int(image_y)], - init_image=init_image, - init_weight=0.0, - # clip.available_models() - # ['RN50', 'RN101', 'RN50x4', 'ViT-B/32'] - # Visual Transformer seems to be the smallest - clip_model="ViT-B/32", - vqgan_config=f"assets/{vqgan_ckpt}.yaml", - vqgan_checkpoint=f"assets/{vqgan_ckpt}.ckpt", - step_size=0.05, - cutn=64, - cut_pow=1.0, - display_freq=50, - seed=seed, - ) - - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - self.device = device - print("Using device:", device) - - def load_model( - self, prev_model: nn.Module = None, prev_perceptor: nn.Module = None - ) -> Optional[Tuple[nn.Module, nn.Module]]: - if self.continue_prev_run is True: - self.model = prev_model - self.perceptor = prev_perceptor - return None - - else: - self.model = load_vqgan_model( - self.args.vqgan_config, self.args.vqgan_checkpoint - ).to(self.device) - - self.perceptor = ( - clip.load(self.args.clip_model, jit=False)[0] - .eval() - .requires_grad_(False) - .to(self.device) - ) - - return self.model, self.perceptor - - def model_init(self, init_image: Image.Image = None) -> None: - cut_size = self.perceptor.visual.input_resolution - e_dim = self.model.quantize.e_dim - f = 2 ** (self.model.decoder.num_resolutions - 1) - self.make_cutouts = MakeCutouts( - cut_size, self.args.cutn, cut_pow=self.args.cut_pow - ) - n_toks = self.model.quantize.n_e - toksX, toksY = self.args.size[0] // f, self.args.size[1] // f - sideX, sideY = toksX * f, toksY * f - self.z_min = self.model.quantize.embedding.weight.min(dim=0).values[ - None, :, None, None - ] - self.z_max = self.model.quantize.embedding.weight.max(dim=0).values[ - None, :, None, None - ] - - if self.seed is not None: - torch.manual_seed(self.seed) - else: - self.seed = torch.seed() # Trigger a seed, retrieve the utilized seed - - # Initialization order: continue_prev_im, init_image, then only random init - if init_image is not None: - init_image = init_image.resize((sideX, sideY), Image.LANCZOS) - self.z, *_ = self.model.encode( - TF.to_tensor(init_image).to(self.device).unsqueeze(0) * 2 - 1 - ) - elif self.args.init_image: - pil_image = self.args.init_image - pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS) - self.z, *_ = self.model.encode( - TF.to_tensor(pil_image).to(self.device).unsqueeze(0) * 2 - 1 - ) - else: - one_hot = F.one_hot( - torch.randint(n_toks, 
[toksY * toksX], device=self.device), n_toks - ).float() - self.z = one_hot @ self.model.quantize.embedding.weight - self.z = self.z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2) - self.z_orig = self.z.clone() - self.z.requires_grad_(True) - self.opt = optim.Adam([self.z], lr=self.args.step_size) - - self.normalize = transforms.Normalize( - mean=[0.48145466, 0.4578275, 0.40821073], - std=[0.26862954, 0.26130258, 0.27577711], - ) - - self.pMs = [] - - for prompt in self.args.prompts: - txt, weight, stop = parse_prompt(prompt) - embed = self.perceptor.encode_text( - clip.tokenize(txt).to(self.device) - ).float() - self.pMs.append(Prompt(embed, weight, stop).to(self.device)) - - for uploaded_image in self.args.image_prompts: - # path, weight, stop = parse_prompt(prompt) - # img = resize_image(Image.open(fetch(path)).convert("RGB"), (sideX, sideY)) - img = resize_image(uploaded_image.convert("RGB"), (sideX, sideY)) - batch = self.make_cutouts(TF.to_tensor(img).unsqueeze(0).to(self.device)) - embed = self.perceptor.encode_image(self.normalize(batch)).float() - self.pMs.append(Prompt(embed, weight, stop).to(self.device)) - - for seed, weight in zip( - self.args.noise_prompt_seeds, self.args.noise_prompt_weights - ): - gen = torch.Generator().manual_seed(seed) - embed = torch.empty([1, self.perceptor.visual.output_dim]).normal_( - generator=gen - ) - self.pMs.append(Prompt(embed, weight).to(self.device)) - - def _ascend_txt(self) -> List: - out = synth(self.model, self.z) - iii = self.perceptor.encode_image( - self.normalize(self.make_cutouts(out)) - ).float() - - result = [] - - if self.args.init_weight: - result.append(F.mse_loss(self.z, self.z_orig) * self.args.init_weight / 2) - - for prompt in self.pMs: - result.append(prompt(iii)) - - return result - - def iterate(self) -> Tuple[List[float], Image.Image]: - # Forward prop - self.opt.zero_grad() - losses = self._ascend_txt() - - # Grab an image - im: Image.Image = checkin(self.model, self.z) - - # Backprop - loss = sum(losses) - loss.backward() - self.opt.step() - with torch.no_grad(): - self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max)) - - # Output stuff useful for humans - return [loss.item() for loss in losses], im +from logic import VQGANCLIPRun class MSEVQGANCLIPRun(VQGANCLIPRun): @@ -258,13 +35,15 @@ def __init__( image_prompts: List[Image.Image] = [], continue_prev_run: bool = False, seed: Optional[int] = None, - # MSE VQGAN-CLIP options + # MSE VQGAN-CLIP options from notebook use_augs: bool = True, noise_fac: float = 0.1, use_noise: Optional[float] = None, mse_withzeros=True, mse_decay_rate=50, mse_epoches=5, + # Added options + mse_weight=0.5, ) -> None: super().__init__() self.text_input = text_input @@ -291,7 +70,7 @@ def __init__( noise_prompt_weights=[], size=[int(image_x), int(image_y)], init_image=init_image, - init_weight=0.5, # differ from standard VQGAN CLIP + init_weight=mse_weight, clip_model="ViT-B/32", vqgan_config=f"assets/{vqgan_ckpt}.yaml", vqgan_checkpoint=f"assets/{vqgan_ckpt}.ckpt", From 723ca92eb8136840a5d39fb310106b69138ea0bf Mon Sep 17 00:00:00 2001 From: Tan Nian Wei Date: Wed, 13 Oct 2021 01:22:52 +0800 Subject: [PATCH 3/9] feat: added MSE schedule --- mseapp.py | 24 +++++++++++++++++++++--- mselogic.py | 43 ++++++++++++++++++++++++------------------- 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/mseapp.py b/mseapp.py index 7ebd48c..d4d6ae8 100644 --- a/mseapp.py +++ b/mseapp.py @@ -35,6 +35,8 @@ def generate_image( continue_prev_run: bool = False, seed: Optional[int] = 
None, mse_weight: float = 0, + mse_weight_decay: float = 0, + mse_weight_decay_steps: int = 0, ) -> None: ### Init ------------------------------------------------------------------- @@ -52,9 +54,9 @@ def generate_image( noise_fac=0.1, use_noise=None, mse_withzeros=True, - mse_decay_rate=50, - mse_epoches=5, mse_weight=mse_weight, + mse_weight_decay=mse_weight_decay, + mse_weight_decay_steps=mse_weight_decay_steps, ) ### Load model ------------------------------------------------------------- @@ -367,10 +369,24 @@ def generate_image( mse_weight = st.sidebar.number_input( "MSE weight", value=defaults["mse_weight"], - min_value=0.0, + # min_value=0.0, # leave this out to allow creativity step=0.05, help="Set weights for MSE regularization", ) + mse_weight_decay = st.sidebar.number_input( + "Decay MSE weight by ...", + value=0.0, + # min_value=0.0, # leave this out to allow creativity + step=0.05, + help="Subtracts MSE weight by this amount at every step change. MSE weight change stops at zero", + ) + mse_weight_decay_steps = st.sidebar.number_input( + "... every N steps", + value=0, + min_value=0, + step=1, + help="Number of steps to subtract MSE weight. Leave zero for no weight decay", + ) submitted = st.form_submit_button("Run!") # End of form @@ -424,6 +440,8 @@ def generate_image( image_prompts=image_prompts, continue_prev_run=continue_prev_run, mse_weight=mse_weight, + mse_weight_decay=mse_weight_decay, + mse_weight_decay_steps=mse_weight_decay_steps, ) vid_display_slot.video("temp.mp4") # debug_slot.write(st.session_state) # DEBUG diff --git a/mselogic.py b/mselogic.py index 9d2aec7..6ce02d0 100644 --- a/mselogic.py +++ b/mselogic.py @@ -40,10 +40,12 @@ def __init__( noise_fac: float = 0.1, use_noise: Optional[float] = None, mse_withzeros=True, - mse_decay_rate=50, - mse_epoches=5, + # mse_decay_rate=50, + # mse_epoches=5, # Added options mse_weight=0.5, + mse_weight_decay=0.1, + mse_weight_decay_steps=50, ) -> None: super().__init__() self.text_input = text_input @@ -91,14 +93,11 @@ def __init__( self.noise_fac = noise_fac self.use_noise = use_noise self.mse_withzeros = mse_withzeros - self.mse_decay_rate = mse_decay_rate - self.mse_epoches = mse_epoches - - # Added init for MSE VQGAN - if self.args.init_weight: - self.mse_decay = self.args.init_weight / self.mse_epoches + self.mse_weight_decay = mse_weight_decay + self.mse_weight_decay_steps = mse_weight_decay_steps self.mse_weight = self.args.init_weight + self.init_mse_weight = self.mse_weight self.augs = nn.Sequential( K.RandomHorizontalFlip(p=0.5), @@ -241,24 +240,30 @@ def _ascend_txt(self) -> List: if self.args.init_weight: result.append(F.mse_loss(self.z, self.z_orig) * self.mse_weight / 2) - # result.append(F.mse_loss(z, z_orig) * ((1/torch.tensor((i)*2 + 1))*mse_weight) / 2) with torch.no_grad(): + # if not the first step + # and is time for step change + # and both weight decay steps and magnitude are nonzero + # and MSE isn't zero already if ( self.iterate_counter > 0 - and self.iterate_counter % self.mse_decay_rate == 0 - and self.iterate_counter <= self.mse_decay_rate * self.mse_epoches + and self.iterate_counter % self.mse_weight_decay_steps == 0 + and self.mse_weight_decay != 0 + and self.mse_weight_decay_steps != 0 + and self.mse_weight != 0 ): + self.mse_weight = self.mse_weight - self.mse_weight_decay - if ( - self.mse_weight - self.mse_decay > 0 - and self.mse_weight - self.mse_decay >= self.mse_decay - ): - self.mse_weight = self.mse_weight - self.mse_decay - print(f"updated mse weight: {self.mse_weight}") + # 
Don't allow changing sign + # Basically, caps MSE at zero if decreasing from positive + # But, also prevents MSE from becoming positive if -MSE intended + if self.init_mse_weight > 0: + self.mse_weight = max(self.mse_weight, 0) else: - self.mse_weight = 0 - print(f"updated mse weight: {self.mse_weight}") + self.mse_weight = min(self.mse_weight, 0) + + print(f"updated mse weight: {self.mse_weight}") for prompt in self.pMs: result.append(prompt(iii)) From f4e86bbeba3c3bacb858c2acf2a860762c2e9f34 Mon Sep 17 00:00:00 2001 From: Tan Nian Wei Date: Thu, 14 Oct 2021 00:42:45 +0800 Subject: [PATCH 4/9] feat: nested mse option, added to metadata output --- mseapp.py | 57 ++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/mseapp.py b/mseapp.py index d4d6ae8..d1bf330 100644 --- a/mseapp.py +++ b/mseapp.py @@ -181,6 +181,9 @@ def generate_image( "vqgan_ckpt": vqgan_ckpt, "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"), "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"), + "mse_weight": mse_weight, + "mse_weight_decay": mse_weight_decay, + "mse_weight_decay_steps": mse_weight_decay_steps, }, f, indent=4, @@ -238,6 +241,9 @@ def generate_image( "vqgan_ckpt": vqgan_ckpt, "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"), "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"), + "mse_weight": mse_weight, + "mse_weight_decay": mse_weight_decay, + "mse_weight_decay_steps": mse_weight_decay_steps, }, f, indent=4, @@ -366,27 +372,38 @@ def generate_image( value=defaults["continue_prev_run"], help="Use existing image and existing weights for the next run. If yes, ignores 'Use starting image'", ) - mse_weight = st.sidebar.number_input( - "MSE weight", - value=defaults["mse_weight"], - # min_value=0.0, # leave this out to allow creativity - step=0.05, - help="Set weights for MSE regularization", - ) - mse_weight_decay = st.sidebar.number_input( - "Decay MSE weight by ...", - value=0.0, - # min_value=0.0, # leave this out to allow creativity - step=0.05, - help="Subtracts MSE weight by this amount at every step change. MSE weight change stops at zero", - ) - mse_weight_decay_steps = st.sidebar.number_input( - "... every N steps", - value=0, - min_value=0, - step=1, - help="Number of steps to subtract MSE weight. Leave zero for no weight decay", + + use_mse_reg = st.sidebar.checkbox( + "Use MSE regularization", + value=defaults["use_mse_regularization"], + help="Check to add MSE regularization", ) + mse_weight_widget = st.sidebar.empty() + mse_weight_decay_widget = st.sidebar.empty() + mse_weight_decay_steps = st.sidebar.empty() + + if use_mse_reg is True: + mse_weight = mse_weight_widget.number_input( + "MSE weight", + value=defaults["mse_weight"], + # min_value=0.0, # leave this out to allow creativity + step=0.05, + help="Set weights for MSE regularization", + ) + mse_weight_decay = mse_weight_decay_widget.number_input( + "Decay MSE weight by ...", + value=0.0, + # min_value=0.0, # leave this out to allow creativity + step=0.05, + help="Subtracts MSE weight by this amount at every step change. MSE weight change stops at zero", + ) + mse_weight_decay_steps = mse_weight_decay_steps.number_input( + "... every N steps", + value=0, + min_value=0, + step=1, + help="Number of steps to subtract MSE weight. 
Leave zero for no weight decay", + ) submitted = st.form_submit_button("Run!") # End of form From 5be7c4e028a4ef303a84cee5b13e8c2741915a26 Mon Sep 17 00:00:00 2001 From: Tan Nian Wei Date: Thu, 14 Oct 2021 00:44:54 +0800 Subject: [PATCH 5/9] refactor: trimmed unused code stubs --- mse_vqgan_utils.py | 80 ++-------------------------------------------- 1 file changed, 3 insertions(+), 77 deletions(-) diff --git a/mse_vqgan_utils.py b/mse_vqgan_utils.py index f752154..5649e49 100644 --- a/mse_vqgan_utils.py +++ b/mse_vqgan_utils.py @@ -1,16 +1,9 @@ -import math import torch from torch import nn import torch.nn.functional as F -import io -from omegaconf import OmegaConf -from taming.models import cond_transformer, vqgan -from PIL import Image -from torchvision.transforms import functional as TF import sys -import kornia.augmentation as K -from vqgan_utils import resample, clamp_with_grad, replace_grad, vector_quantize +from vqgan_utils import resample, clamp_with_grad, vector_quantize sys.path.append("./taming-transformers") @@ -27,15 +20,6 @@ def noise_gen(shape): return noise -def one_sided_clip_loss(input, target, labels=None, logit_scale=100): - input_normed = F.normalize(input, dim=-1) - target_normed = F.normalize(target, dim=-1) - logits = input_normed @ target_normed.T * logit_scale - if labels is None: - labels = torch.arange(len(input), device=logits.device) - return F.cross_entropy(logits, labels) - - class MSEMakeCutouts(nn.Module): def __init__(self, cut_size, cutn, cut_pow=1.0, augs=None, noise_fac=0.1): super().__init__() @@ -57,9 +41,9 @@ def forward(self, input): min_size = min(sideX, sideY, self.cut_size) cutouts = [] - min_size_width = min(sideX, sideY) + # min_size_width = min(sideX, sideY) - for ii in range(self.cutn): + for _ in range(self.cutn): size = int( torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size @@ -93,64 +77,6 @@ def forward(self, input): return diff.mean(dim=1).sqrt().mean() -class GaussianBlur2d(nn.Module): - def __init__(self, sigma, window=0, mode="reflect", value=0): - super().__init__() - self.mode = mode - self.value = value - if not window: - window = max(math.ceil((sigma * 6 + 1) / 2) * 2 - 1, 3) - if sigma: - kernel = torch.exp( - -((torch.arange(window) - window // 2) ** 2) / 2 / sigma ** 2 - ) - kernel /= kernel.sum() - else: - kernel = torch.ones([1]) - self.register_buffer("kernel", kernel) - - def forward(self, input): - n, c, h, w = input.shape - input = input.view([n * c, 1, h, w]) - start_pad = (self.kernel.shape[0] - 1) // 2 - end_pad = self.kernel.shape[0] // 2 - input = F.pad( - input, (start_pad, end_pad, start_pad, end_pad), self.mode, self.value - ) - input = F.conv2d(input, self.kernel[None, None, None, :]) - input = F.conv2d(input, self.kernel[None, None, :, None]) - return input.view([n, c, h, w]) - - -class EMATensor(nn.Module): - """implmeneted by Katherine Crowson""" - - def __init__(self, tensor, decay): - super().__init__() - self.tensor = nn.Parameter(tensor) - self.register_buffer("biased", torch.zeros_like(tensor)) - self.register_buffer("average", torch.zeros_like(tensor)) - self.decay = decay - self.register_buffer("accum", torch.tensor(1.0)) - self.update() - - @torch.no_grad() - def update(self): - if not self.training: - raise RuntimeError("update() should only be called during training") - - self.accum *= self.decay - self.biased.mul_(self.decay) - self.biased.add_((1 - self.decay) * self.tensor) - self.average.copy_(self.biased) - self.average.div_(1 - self.accum) - - def forward(self): - if 
self.training: - return self.tensor - return self.average - - def synth_mse(model, z, is_openimages_f16_8192: bool = False, quantize: bool = True): # z_mask not defined in notebook, appears to be unused # if constraint_regions: From 9a10ab2813b11da174189c172f931cd810fa3dbb Mon Sep 17 00:00:00 2001 From: Tan Nian Wei Date: Thu, 14 Oct 2021 00:49:59 +0800 Subject: [PATCH 6/9] refactor: separated out non MSE features --- mselogic.py | 27 ++++----------------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/mselogic.py b/mselogic.py index 6ce02d0..ac4eb0e 100644 --- a/mselogic.py +++ b/mselogic.py @@ -2,22 +2,13 @@ from PIL import Image import argparse import clip -from vqgan_utils import ( - load_vqgan_model, - MakeCutouts, - parse_prompt, - resize_image, - Prompt, - synth, - checkin, -) +from vqgan_utils import MakeCutouts, parse_prompt, resize_image, Prompt, synth, checkin import torch from torchvision.transforms import functional as TF import torch.nn as nn from torch.nn import functional as F from torch import optim from torchvision import transforms -from mse_vqgan_utils import synth_mse, MSEMakeCutouts, noise_gen import kornia.augmentation as K from logic import VQGANCLIPRun @@ -131,11 +122,8 @@ def model_init(self, init_image: Image.Image = None) -> None: ] f = 2 ** (self.model.decoder.num_resolutions - 1) - self.make_cutouts = MSEMakeCutouts( - cut_size, - self.args.cutn, - cut_pow=self.args.cut_pow, - augs=self.augs if self.use_augs is True else None, + self.make_cutouts = MakeCutouts( + cut_size, self.args.cutn, cut_pow=self.args.cut_pow ) n_toks = self.model.quantize.n_e toksX, toksY = self.args.size[0] // f, self.args.size[1] // f @@ -224,14 +212,7 @@ def model_init(self, init_image: Image.Image = None) -> None: self.pMs.append(Prompt(embed, weight).to(self.device)) def _ascend_txt(self) -> List: - out = synth_mse( - self.model, - self.z, - quantize=True, - is_openimages_f16_8192=True - if self.args.vqgan_checkpoint == "vqgan_openimages_f16_8192.ckpt" - else False, - ) + out = synth(self.model, self.z) cutouts = self.make_cutouts(out) iii = self.perceptor.encode_image(self.normalize(cutouts)).float() From 11ffc12bc90c99506eb5e2f352bdcae53b5b9a13 Mon Sep 17 00:00:00 2001 From: Tan Nian Wei Date: Thu, 14 Oct 2021 00:57:34 +0800 Subject: [PATCH 7/9] refactor: rm'ed temp modules --- logic.py | 251 +++++++++++++++++++++++++++++++++++++++++ mse_vqgan_utils.py | 98 ---------------- mseapp.py | 2 +- mselogic.py | 273 --------------------------------------------- 4 files changed, 252 insertions(+), 372 deletions(-) delete mode 100644 mse_vqgan_utils.py delete mode 100644 mselogic.py diff --git a/logic.py b/logic.py index 8a64fbb..411841e 100644 --- a/logic.py +++ b/logic.py @@ -241,3 +241,254 @@ def iterate(self) -> Tuple[List[float], Image.Image]: # Output stuff useful for humans return [loss.item() for loss in losses], im + + +class MSEVQGANCLIPRun(VQGANCLIPRun): + def __init__( + # Inputs + self, + text_input: str = "the first day of the waters", + vqgan_ckpt: str = "vqgan_imagenet_f16_16384", + num_steps: int = 300, + image_x: int = 300, + image_y: int = 300, + init_image: Optional[Image.Image] = None, + image_prompts: List[Image.Image] = [], + continue_prev_run: bool = False, + seed: Optional[int] = None, + # MSE VQGAN-CLIP options from notebook + use_augs: bool = True, + noise_fac: float = 0.1, + use_noise: Optional[float] = None, + mse_withzeros=True, + # mse_decay_rate=50, + # mse_epoches=5, + # Added options + mse_weight=0.5, + mse_weight_decay=0.1, + 
mse_weight_decay_steps=50, + ) -> None: + super().__init__() + self.text_input = text_input + self.vqgan_ckpt = vqgan_ckpt + self.num_steps = num_steps + self.image_x = image_x + self.image_y = image_y + self.init_image = init_image + self.image_prompts = image_prompts + self.continue_prev_run = continue_prev_run + self.seed = seed + + # Setup ------------------------------------------------------------------------------ + # Split text by "|" symbol + texts = [phrase.strip() for phrase in text_input.split("|")] + if texts == [""]: + texts = [] + + # Leaving most of this untouched + self.args = argparse.Namespace( + prompts=texts, + image_prompts=image_prompts, + noise_prompt_seeds=[], + noise_prompt_weights=[], + size=[int(image_x), int(image_y)], + init_image=init_image, + init_weight=mse_weight, + clip_model="ViT-B/32", + vqgan_config=f"assets/{vqgan_ckpt}.yaml", + vqgan_checkpoint=f"assets/{vqgan_ckpt}.ckpt", + step_size=0.05, + cutn=64, + cut_pow=1.0, + display_freq=50, + seed=seed, + ) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.device = device + print("Using device:", device) + + # TODO: MSE regularized options here + self.iterate_counter = 0 + self.use_augs = use_augs + self.noise_fac = noise_fac + self.use_noise = use_noise + self.mse_withzeros = mse_withzeros + self.mse_weight_decay = mse_weight_decay + self.mse_weight_decay_steps = mse_weight_decay_steps + + self.mse_weight = self.args.init_weight + self.init_mse_weight = self.mse_weight + + def model_init(self, init_image: Image.Image = None) -> None: + cut_size = self.perceptor.visual.input_resolution + + if self.args.vqgan_checkpoint == "vqgan_openimages_f16_8192.ckpt": + self.e_dim = 256 + self.n_toks = self.model.quantize.n_embed + self.z_min = self.model.quantize.embed.weight.min(dim=0).values[ + None, :, None, None + ] + self.z_max = self.model.quantize.embed.weight.max(dim=0).values[ + None, :, None, None + ] + else: + self.e_dim = self.model.quantize.e_dim + self.n_toks = self.model.quantize.n_e + self.z_min = self.model.quantize.embedding.weight.min(dim=0).values[ + None, :, None, None + ] + self.z_max = self.model.quantize.embedding.weight.max(dim=0).values[ + None, :, None, None + ] + + f = 2 ** (self.model.decoder.num_resolutions - 1) + self.make_cutouts = MakeCutouts( + cut_size, self.args.cutn, cut_pow=self.args.cut_pow + ) + n_toks = self.model.quantize.n_e + toksX, toksY = self.args.size[0] // f, self.args.size[1] // f + sideX, sideY = toksX * f, toksY * f # notebook used 16 instead of f here + + if self.seed is not None: + torch.manual_seed(self.seed) + else: + self.seed = torch.seed() # Trigger a seed, retrieve the utilized seed + + # Initialization order: continue_prev_im, init_image, then only random init + if init_image is not None: + init_image = init_image.resize((sideX, sideY), Image.LANCZOS) + init_image = TF.to_tensor(init_image) + + if self.args.use_noise: + init_image = init_image + self.args.use_noise * torch.randn_like( + init_image + ) + + self.z, *_ = self.model.encode( + TF.to_tensor(init_image).to(self.device).unsqueeze(0) * 2 - 1 + ) + elif self.args.init_image: + pil_image = self.args.init_image + pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS) + pil_image = TF.to_tensor(pil_image) + + if self.args.use_noise: + pil_image = pil_image + self.args.use_noise * torch.randn_like( + pil_image + ) + + self.z, *_ = self.model.encode( + TF.to_tensor(pil_image).to(self.device).unsqueeze(0) * 2 - 1 + ) + else: + one_hot = F.one_hot( + 
torch.randint(n_toks, [toksY * toksX], device=self.device), n_toks + ).float() + + if self.args.vqgan_checkpoint == "vqgan_openimages_f16_8192.ckpt": + self.z = one_hot @ self.model.quantize.embed.weight + else: + self.z = one_hot @ self.model.quantize.embedding.weight + + self.z = self.z.view([-1, toksY, toksX, self.e_dim]).permute(0, 3, 1, 2) + + if self.mse_withzeros and not self.args.init_image: + self.z_orig = torch.zeros_like(self.z) + else: + self.z_orig = self.z.clone() + + self.z.requires_grad_(True) + self.opt = optim.Adam([self.z], lr=self.args.step_size) + + self.normalize = transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ) + + self.pMs = [] + + for prompt in self.args.prompts: + txt, weight, stop = parse_prompt(prompt) + embed = self.perceptor.encode_text( + clip.tokenize(txt).to(self.device) + ).float() + self.pMs.append(Prompt(embed, weight, stop).to(self.device)) + + for uploaded_image in self.args.image_prompts: + # path, weight, stop = parse_prompt(prompt) + # img = resize_image(Image.open(fetch(path)).convert("RGB"), (sideX, sideY)) + img = resize_image(uploaded_image.convert("RGB"), (sideX, sideY)) + batch = self.make_cutouts(TF.to_tensor(img).unsqueeze(0).to(self.device)) + embed = self.perceptor.encode_image(self.normalize(batch)).float() + self.pMs.append(Prompt(embed, weight, stop).to(self.device)) + + for seed, weight in zip( + self.args.noise_prompt_seeds, self.args.noise_prompt_weights + ): + gen = torch.Generator().manual_seed(seed) + embed = torch.empty([1, self.perceptor.visual.output_dim]).normal_( + generator=gen + ) + self.pMs.append(Prompt(embed, weight).to(self.device)) + + def _ascend_txt(self) -> List: + out = synth(self.model, self.z) + + cutouts = self.make_cutouts(out) + iii = self.perceptor.encode_image(self.normalize(cutouts)).float() + + result = [] + + if self.args.init_weight: + result.append(F.mse_loss(self.z, self.z_orig) * self.mse_weight / 2) + + with torch.no_grad(): + # if not the first step + # and is time for step change + # and both weight decay steps and magnitude are nonzero + # and MSE isn't zero already + if ( + self.iterate_counter > 0 + and self.iterate_counter % self.mse_weight_decay_steps == 0 + and self.mse_weight_decay != 0 + and self.mse_weight_decay_steps != 0 + and self.mse_weight != 0 + ): + self.mse_weight = self.mse_weight - self.mse_weight_decay + + # Don't allow changing sign + # Basically, caps MSE at zero if decreasing from positive + # But, also prevents MSE from becoming positive if -MSE intended + if self.init_mse_weight > 0: + self.mse_weight = max(self.mse_weight, 0) + else: + self.mse_weight = min(self.mse_weight, 0) + + print(f"updated mse weight: {self.mse_weight}") + + for prompt in self.pMs: + result.append(prompt(iii)) + + return result + + def iterate(self) -> Tuple[List[float], Image.Image]: + # Forward prop + self.opt.zero_grad() + losses = self._ascend_txt() + + # Grab an image + im: Image.Image = checkin(self.model, self.z) + + # Backprop + loss = sum(losses) + loss.backward() + self.opt.step() + # with torch.no_grad(): + # self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max)) + + # Advance iteration counter + self.iterate_counter += 1 + + # Output stuff useful for humans + return [loss.item() for loss in losses], im diff --git a/mse_vqgan_utils.py b/mse_vqgan_utils.py deleted file mode 100644 index 5649e49..0000000 --- a/mse_vqgan_utils.py +++ /dev/null @@ -1,98 +0,0 @@ -import torch -from torch import nn -import 
torch.nn.functional as F -import sys - -from vqgan_utils import resample, clamp_with_grad, vector_quantize - -sys.path.append("./taming-transformers") - - -def noise_gen(shape): - n, c, h, w = shape - noise = torch.zeros([n, c, 1, 1]) - for i in reversed(range(5)): - h_cur, w_cur = h // 2 ** i, w // 2 ** i - noise = F.interpolate( - noise, (h_cur, w_cur), mode="bicubic", align_corners=False - ) - noise += torch.randn([n, c, h_cur, w_cur]) / 5 - return noise - - -class MSEMakeCutouts(nn.Module): - def __init__(self, cut_size, cutn, cut_pow=1.0, augs=None, noise_fac=0.1): - super().__init__() - self.cut_size = cut_size - self.cutn = cutn - self.cut_pow = cut_pow - self.augs = augs - self.noise_fac = noise_fac - - self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size)) - self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size)) - - def set_cut_pow(self, cut_pow): - self.cut_pow = cut_pow - - def forward(self, input): - sideY, sideX = input.shape[2:4] - max_size = min(sideX, sideY) - min_size = min(sideX, sideY, self.cut_size) - cutouts = [] - - # min_size_width = min(sideX, sideY) - - for _ in range(self.cutn): - - size = int( - torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size - ) - - offsetx = torch.randint(0, sideX - size + 1, ()) - offsety = torch.randint(0, sideY - size + 1, ()) - cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size] - cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) - - cutouts = torch.cat(cutouts, dim=0) - - if self.augs is not None: - cutouts = self.augs(cutouts) - - if self.noise_fac: - facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_( - 0, self.noise_fac - ) - cutouts = cutouts + facs * torch.randn_like(cutouts) - - return clamp_with_grad(cutouts, 0, 1) - - -class TVLoss(nn.Module): - def forward(self, input): - input = F.pad(input, (0, 1, 0, 1), "replicate") - x_diff = input[..., :-1, 1:] - input[..., :-1, :-1] - y_diff = input[..., 1:, :-1] - input[..., :-1, :-1] - diff = x_diff ** 2 + y_diff ** 2 + 1e-8 - return diff.mean(dim=1).sqrt().mean() - - -def synth_mse(model, z, is_openimages_f16_8192: bool = False, quantize: bool = True): - # z_mask not defined in notebook, appears to be unused - # if constraint_regions: - # z = replace_grad(z, z * z_mask) - - if quantize: - if is_openimages_f16_8192: - z_q = vector_quantize(z.movedim(1, 3), model.quantize.embed.weight).movedim( - 3, 1 - ) - else: - z_q = vector_quantize( - z.movedim(1, 3), model.quantize.embedding.weight - ).movedim(3, 1) - - else: - z_q = z.model - - return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1) diff --git a/mseapp.py b/mseapp.py index d1bf330..c475f3e 100644 --- a/mseapp.py +++ b/mseapp.py @@ -21,7 +21,7 @@ from omegaconf import OmegaConf import imageio import numpy as np -from mselogic import MSEVQGANCLIPRun +from logic import MSEVQGANCLIPRun def generate_image( diff --git a/mselogic.py b/mselogic.py deleted file mode 100644 index ac4eb0e..0000000 --- a/mselogic.py +++ /dev/null @@ -1,273 +0,0 @@ -from typing import Optional, List, Tuple -from PIL import Image -import argparse -import clip -from vqgan_utils import MakeCutouts, parse_prompt, resize_image, Prompt, synth, checkin -import torch -from torchvision.transforms import functional as TF -import torch.nn as nn -from torch.nn import functional as F -from torch import optim -from torchvision import transforms -import kornia.augmentation as K -from logic import VQGANCLIPRun - - -class MSEVQGANCLIPRun(VQGANCLIPRun): - def __init__( - # Inputs - 
self, - text_input: str = "the first day of the waters", - vqgan_ckpt: str = "vqgan_imagenet_f16_16384", - num_steps: int = 300, - image_x: int = 300, - image_y: int = 300, - init_image: Optional[Image.Image] = None, - image_prompts: List[Image.Image] = [], - continue_prev_run: bool = False, - seed: Optional[int] = None, - # MSE VQGAN-CLIP options from notebook - use_augs: bool = True, - noise_fac: float = 0.1, - use_noise: Optional[float] = None, - mse_withzeros=True, - # mse_decay_rate=50, - # mse_epoches=5, - # Added options - mse_weight=0.5, - mse_weight_decay=0.1, - mse_weight_decay_steps=50, - ) -> None: - super().__init__() - self.text_input = text_input - self.vqgan_ckpt = vqgan_ckpt - self.num_steps = num_steps - self.image_x = image_x - self.image_y = image_y - self.init_image = init_image - self.image_prompts = image_prompts - self.continue_prev_run = continue_prev_run - self.seed = seed - - # Setup ------------------------------------------------------------------------------ - # Split text by "|" symbol - texts = [phrase.strip() for phrase in text_input.split("|")] - if texts == [""]: - texts = [] - - # Leaving most of this untouched - self.args = argparse.Namespace( - prompts=texts, - image_prompts=image_prompts, - noise_prompt_seeds=[], - noise_prompt_weights=[], - size=[int(image_x), int(image_y)], - init_image=init_image, - init_weight=mse_weight, - clip_model="ViT-B/32", - vqgan_config=f"assets/{vqgan_ckpt}.yaml", - vqgan_checkpoint=f"assets/{vqgan_ckpt}.ckpt", - step_size=0.05, - cutn=64, - cut_pow=1.0, - display_freq=50, - seed=seed, - ) - - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - self.device = device - print("Using device:", device) - - # TODO: MSE regularized options here - self.iterate_counter = 0 - self.use_augs = use_augs - self.noise_fac = noise_fac - self.use_noise = use_noise - self.mse_withzeros = mse_withzeros - self.mse_weight_decay = mse_weight_decay - self.mse_weight_decay_steps = mse_weight_decay_steps - - self.mse_weight = self.args.init_weight - self.init_mse_weight = self.mse_weight - - self.augs = nn.Sequential( - K.RandomHorizontalFlip(p=0.5), - K.RandomAffine( - degrees=30, translate=0.1, p=0.8, padding_mode="border" - ), # padding_mode=2 - K.RandomPerspective(0.2, p=0.4), - K.ColorJitter(hue=0.01, saturation=0.01, p=0.7), - ) - - def model_init(self, init_image: Image.Image = None) -> None: - cut_size = self.perceptor.visual.input_resolution - - if self.args.vqgan_checkpoint == "vqgan_openimages_f16_8192.ckpt": - self.e_dim = 256 - self.n_toks = self.model.quantize.n_embed - self.z_min = self.model.quantize.embed.weight.min(dim=0).values[ - None, :, None, None - ] - self.z_max = self.model.quantize.embed.weight.max(dim=0).values[ - None, :, None, None - ] - else: - self.e_dim = self.model.quantize.e_dim - self.n_toks = self.model.quantize.n_e - self.z_min = self.model.quantize.embedding.weight.min(dim=0).values[ - None, :, None, None - ] - self.z_max = self.model.quantize.embedding.weight.max(dim=0).values[ - None, :, None, None - ] - - f = 2 ** (self.model.decoder.num_resolutions - 1) - self.make_cutouts = MakeCutouts( - cut_size, self.args.cutn, cut_pow=self.args.cut_pow - ) - n_toks = self.model.quantize.n_e - toksX, toksY = self.args.size[0] // f, self.args.size[1] // f - sideX, sideY = toksX * f, toksY * f # notebook used 16 instead of f here - - if self.seed is not None: - torch.manual_seed(self.seed) - else: - self.seed = torch.seed() # Trigger a seed, retrieve the utilized seed - - # Initialization order: 
continue_prev_im, init_image, then only random init - if init_image is not None: - init_image = init_image.resize((sideX, sideY), Image.LANCZOS) - init_image = TF.to_tensor(init_image) - - if self.args.use_noise: - init_image = init_image + self.args.use_noise * torch.randn_like( - init_image - ) - - self.z, *_ = self.model.encode( - TF.to_tensor(init_image).to(self.device).unsqueeze(0) * 2 - 1 - ) - elif self.args.init_image: - pil_image = self.args.init_image - pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS) - pil_image = TF.to_tensor(pil_image) - - if self.args.use_noise: - pil_image = pil_image + self.args.use_noise * torch.randn_like( - pil_image - ) - - self.z, *_ = self.model.encode( - TF.to_tensor(pil_image).to(self.device).unsqueeze(0) * 2 - 1 - ) - else: - one_hot = F.one_hot( - torch.randint(n_toks, [toksY * toksX], device=self.device), n_toks - ).float() - - if self.args.vqgan_checkpoint == "vqgan_openimages_f16_8192.ckpt": - self.z = one_hot @ self.model.quantize.embed.weight - else: - self.z = one_hot @ self.model.quantize.embedding.weight - - self.z = self.z.view([-1, toksY, toksX, self.e_dim]).permute(0, 3, 1, 2) - - if self.mse_withzeros and not self.args.init_image: - self.z_orig = torch.zeros_like(self.z) - else: - self.z_orig = self.z.clone() - - self.z.requires_grad_(True) - self.opt = optim.Adam([self.z], lr=self.args.step_size) - - self.normalize = transforms.Normalize( - mean=[0.48145466, 0.4578275, 0.40821073], - std=[0.26862954, 0.26130258, 0.27577711], - ) - - self.pMs = [] - - for prompt in self.args.prompts: - txt, weight, stop = parse_prompt(prompt) - embed = self.perceptor.encode_text( - clip.tokenize(txt).to(self.device) - ).float() - self.pMs.append(Prompt(embed, weight, stop).to(self.device)) - - for uploaded_image in self.args.image_prompts: - # path, weight, stop = parse_prompt(prompt) - # img = resize_image(Image.open(fetch(path)).convert("RGB"), (sideX, sideY)) - img = resize_image(uploaded_image.convert("RGB"), (sideX, sideY)) - batch = self.make_cutouts(TF.to_tensor(img).unsqueeze(0).to(self.device)) - embed = self.perceptor.encode_image(self.normalize(batch)).float() - self.pMs.append(Prompt(embed, weight, stop).to(self.device)) - - for seed, weight in zip( - self.args.noise_prompt_seeds, self.args.noise_prompt_weights - ): - gen = torch.Generator().manual_seed(seed) - embed = torch.empty([1, self.perceptor.visual.output_dim]).normal_( - generator=gen - ) - self.pMs.append(Prompt(embed, weight).to(self.device)) - - def _ascend_txt(self) -> List: - out = synth(self.model, self.z) - - cutouts = self.make_cutouts(out) - iii = self.perceptor.encode_image(self.normalize(cutouts)).float() - - result = [] - - if self.args.init_weight: - result.append(F.mse_loss(self.z, self.z_orig) * self.mse_weight / 2) - - with torch.no_grad(): - # if not the first step - # and is time for step change - # and both weight decay steps and magnitude are nonzero - # and MSE isn't zero already - if ( - self.iterate_counter > 0 - and self.iterate_counter % self.mse_weight_decay_steps == 0 - and self.mse_weight_decay != 0 - and self.mse_weight_decay_steps != 0 - and self.mse_weight != 0 - ): - self.mse_weight = self.mse_weight - self.mse_weight_decay - - # Don't allow changing sign - # Basically, caps MSE at zero if decreasing from positive - # But, also prevents MSE from becoming positive if -MSE intended - if self.init_mse_weight > 0: - self.mse_weight = max(self.mse_weight, 0) - else: - self.mse_weight = min(self.mse_weight, 0) - - print(f"updated mse weight: 
{self.mse_weight}") - - for prompt in self.pMs: - result.append(prompt(iii)) - - return result - - def iterate(self) -> Tuple[List[float], Image.Image]: - # Forward prop - self.opt.zero_grad() - losses = self._ascend_txt() - - # Grab an image - im: Image.Image = checkin(self.model, self.z) - - # Backprop - loss = sum(losses) - loss.backward() - self.opt.step() - # with torch.no_grad(): - # self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max)) - - # Advance iteration counter - self.iterate_counter += 1 - - # Output stuff useful for humans - return [loss.item() for loss in losses], im From c8a35c032ef9653c508b103603ee767e4e7f0953 Mon Sep 17 00:00:00 2001 From: Tan Nian Wei Date: Thu, 14 Oct 2021 01:12:17 +0800 Subject: [PATCH 8/9] refactor: merged all non-UI code --- defaults.yaml | 3 +- logic.py | 250 +++++--------------------------------------------- mseapp.py | 8 +- 3 files changed, 26 insertions(+), 235 deletions(-) diff --git a/defaults.yaml b/defaults.yaml index 793cd69..16c4902 100644 --- a/defaults.yaml +++ b/defaults.yaml @@ -7,4 +7,5 @@ seed: 0 use_starting_image: false use_image_prompts: false continue_prev_run: false -mse_weight: 0.0 \ No newline at end of file +mse_weight: 0.0 +use_mse_regularization: false \ No newline at end of file diff --git a/logic.py b/logic.py index 411841e..741ea5d 100644 --- a/logic.py +++ b/logic.py @@ -66,6 +66,13 @@ def __init__( image_prompts: List[Image.Image] = [], continue_prev_run: bool = False, seed: Optional[int] = None, + mse_weight=0.5, + mse_weight_decay=0.1, + mse_weight_decay_steps=50, + # use_augs: bool = True, + # noise_fac: float = 0.1, + # use_noise: Optional[float] = None, + # mse_withzeros=True, ## **kwargs, # Use this to receive Streamlit objects ## Call from main UI ) -> None: super().__init__() @@ -93,7 +100,7 @@ def __init__( noise_prompt_weights=[], size=[int(image_x), int(image_y)], init_image=init_image, - init_weight=0.0, + init_weight=mse_weight, # clip.available_models() # ['RN50', 'RN101', 'RN50x4', 'ViT-B/32'] # Visual Transformer seems to be the smallest @@ -111,6 +118,16 @@ def __init__( self.device = device print("Using device:", device) + self.iterate_counter = 0 + # self.use_augs = use_augs + # self.noise_fac = noise_fac + # self.use_noise = use_noise + # self.mse_withzeros = mse_withzeros + self.init_mse_weight = mse_weight + self.mse_weight = mse_weight + self.mse_weight_decay = mse_weight_decay + self.mse_weight_decay_steps = mse_weight_decay_steps + def load_model( self, prev_model: nn.Module = None, prev_perceptor: nn.Module = None ) -> Optional[Tuple[nn.Module, nn.Module]]: @@ -217,232 +234,9 @@ def _ascend_txt(self) -> List: result = [] if self.args.init_weight: - result.append(F.mse_loss(self.z, self.z_orig) * self.args.init_weight / 2) - - for prompt in self.pMs: - result.append(prompt(iii)) - - return result - - def iterate(self) -> Tuple[List[float], Image.Image]: - # Forward prop - self.opt.zero_grad() - losses = self._ascend_txt() - - # Grab an image - im: Image.Image = checkin(self.model, self.z) - - # Backprop - loss = sum(losses) - loss.backward() - self.opt.step() - with torch.no_grad(): - self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max)) - - # Output stuff useful for humans - return [loss.item() for loss in losses], im - - -class MSEVQGANCLIPRun(VQGANCLIPRun): - def __init__( - # Inputs - self, - text_input: str = "the first day of the waters", - vqgan_ckpt: str = "vqgan_imagenet_f16_16384", - num_steps: int = 300, - image_x: int = 300, - image_y: int = 300, - init_image: 
Optional[Image.Image] = None, - image_prompts: List[Image.Image] = [], - continue_prev_run: bool = False, - seed: Optional[int] = None, - # MSE VQGAN-CLIP options from notebook - use_augs: bool = True, - noise_fac: float = 0.1, - use_noise: Optional[float] = None, - mse_withzeros=True, - # mse_decay_rate=50, - # mse_epoches=5, - # Added options - mse_weight=0.5, - mse_weight_decay=0.1, - mse_weight_decay_steps=50, - ) -> None: - super().__init__() - self.text_input = text_input - self.vqgan_ckpt = vqgan_ckpt - self.num_steps = num_steps - self.image_x = image_x - self.image_y = image_y - self.init_image = init_image - self.image_prompts = image_prompts - self.continue_prev_run = continue_prev_run - self.seed = seed - - # Setup ------------------------------------------------------------------------------ - # Split text by "|" symbol - texts = [phrase.strip() for phrase in text_input.split("|")] - if texts == [""]: - texts = [] - - # Leaving most of this untouched - self.args = argparse.Namespace( - prompts=texts, - image_prompts=image_prompts, - noise_prompt_seeds=[], - noise_prompt_weights=[], - size=[int(image_x), int(image_y)], - init_image=init_image, - init_weight=mse_weight, - clip_model="ViT-B/32", - vqgan_config=f"assets/{vqgan_ckpt}.yaml", - vqgan_checkpoint=f"assets/{vqgan_ckpt}.ckpt", - step_size=0.05, - cutn=64, - cut_pow=1.0, - display_freq=50, - seed=seed, - ) - - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - self.device = device - print("Using device:", device) - - # TODO: MSE regularized options here - self.iterate_counter = 0 - self.use_augs = use_augs - self.noise_fac = noise_fac - self.use_noise = use_noise - self.mse_withzeros = mse_withzeros - self.mse_weight_decay = mse_weight_decay - self.mse_weight_decay_steps = mse_weight_decay_steps - - self.mse_weight = self.args.init_weight - self.init_mse_weight = self.mse_weight - - def model_init(self, init_image: Image.Image = None) -> None: - cut_size = self.perceptor.visual.input_resolution - - if self.args.vqgan_checkpoint == "vqgan_openimages_f16_8192.ckpt": - self.e_dim = 256 - self.n_toks = self.model.quantize.n_embed - self.z_min = self.model.quantize.embed.weight.min(dim=0).values[ - None, :, None, None - ] - self.z_max = self.model.quantize.embed.weight.max(dim=0).values[ - None, :, None, None - ] - else: - self.e_dim = self.model.quantize.e_dim - self.n_toks = self.model.quantize.n_e - self.z_min = self.model.quantize.embedding.weight.min(dim=0).values[ - None, :, None, None - ] - self.z_max = self.model.quantize.embedding.weight.max(dim=0).values[ - None, :, None, None - ] - - f = 2 ** (self.model.decoder.num_resolutions - 1) - self.make_cutouts = MakeCutouts( - cut_size, self.args.cutn, cut_pow=self.args.cut_pow - ) - n_toks = self.model.quantize.n_e - toksX, toksY = self.args.size[0] // f, self.args.size[1] // f - sideX, sideY = toksX * f, toksY * f # notebook used 16 instead of f here - - if self.seed is not None: - torch.manual_seed(self.seed) - else: - self.seed = torch.seed() # Trigger a seed, retrieve the utilized seed - - # Initialization order: continue_prev_im, init_image, then only random init - if init_image is not None: - init_image = init_image.resize((sideX, sideY), Image.LANCZOS) - init_image = TF.to_tensor(init_image) - - if self.args.use_noise: - init_image = init_image + self.args.use_noise * torch.randn_like( - init_image - ) - - self.z, *_ = self.model.encode( - TF.to_tensor(init_image).to(self.device).unsqueeze(0) * 2 - 1 - ) - elif self.args.init_image: - 
pil_image = self.args.init_image - pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS) - pil_image = TF.to_tensor(pil_image) - - if self.args.use_noise: - pil_image = pil_image + self.args.use_noise * torch.randn_like( - pil_image - ) - - self.z, *_ = self.model.encode( - TF.to_tensor(pil_image).to(self.device).unsqueeze(0) * 2 - 1 - ) - else: - one_hot = F.one_hot( - torch.randint(n_toks, [toksY * toksX], device=self.device), n_toks - ).float() - - if self.args.vqgan_checkpoint == "vqgan_openimages_f16_8192.ckpt": - self.z = one_hot @ self.model.quantize.embed.weight - else: - self.z = one_hot @ self.model.quantize.embedding.weight - - self.z = self.z.view([-1, toksY, toksX, self.e_dim]).permute(0, 3, 1, 2) - - if self.mse_withzeros and not self.args.init_image: - self.z_orig = torch.zeros_like(self.z) - else: - self.z_orig = self.z.clone() - - self.z.requires_grad_(True) - self.opt = optim.Adam([self.z], lr=self.args.step_size) - - self.normalize = transforms.Normalize( - mean=[0.48145466, 0.4578275, 0.40821073], - std=[0.26862954, 0.26130258, 0.27577711], - ) - - self.pMs = [] - - for prompt in self.args.prompts: - txt, weight, stop = parse_prompt(prompt) - embed = self.perceptor.encode_text( - clip.tokenize(txt).to(self.device) - ).float() - self.pMs.append(Prompt(embed, weight, stop).to(self.device)) - - for uploaded_image in self.args.image_prompts: - # path, weight, stop = parse_prompt(prompt) - # img = resize_image(Image.open(fetch(path)).convert("RGB"), (sideX, sideY)) - img = resize_image(uploaded_image.convert("RGB"), (sideX, sideY)) - batch = self.make_cutouts(TF.to_tensor(img).unsqueeze(0).to(self.device)) - embed = self.perceptor.encode_image(self.normalize(batch)).float() - self.pMs.append(Prompt(embed, weight, stop).to(self.device)) - - for seed, weight in zip( - self.args.noise_prompt_seeds, self.args.noise_prompt_weights - ): - gen = torch.Generator().manual_seed(seed) - embed = torch.empty([1, self.perceptor.visual.output_dim]).normal_( - generator=gen - ) - self.pMs.append(Prompt(embed, weight).to(self.device)) - - def _ascend_txt(self) -> List: - out = synth(self.model, self.z) - - cutouts = self.make_cutouts(out) - iii = self.perceptor.encode_image(self.normalize(cutouts)).float() - - result = [] - - if self.args.init_weight: - result.append(F.mse_loss(self.z, self.z_orig) * self.mse_weight / 2) + result.append(F.mse_loss(self.z, self.z_orig) * self.init_mse_weight / 2) + # MSE regularization scheduler with torch.no_grad(): # if not the first step # and is time for step change @@ -484,8 +278,8 @@ def iterate(self) -> Tuple[List[float], Image.Image]: loss = sum(losses) loss.backward() self.opt.step() - # with torch.no_grad(): - # self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max)) + with torch.no_grad(): + self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max)) # Advance iteration counter self.iterate_counter += 1 diff --git a/mseapp.py b/mseapp.py index c475f3e..ddd8f6d 100644 --- a/mseapp.py +++ b/mseapp.py @@ -21,7 +21,7 @@ from omegaconf import OmegaConf import imageio import numpy as np -from logic import MSEVQGANCLIPRun +from logic import VQGANCLIPRun def generate_image( @@ -40,7 +40,7 @@ def generate_image( ) -> None: ### Init ------------------------------------------------------------------- - run = MSEVQGANCLIPRun( + run = VQGANCLIPRun( text_input=text_input, vqgan_ckpt=vqgan_ckpt, num_steps=num_steps, @@ -50,10 +50,6 @@ def generate_image( init_image=init_image, image_prompts=image_prompts, continue_prev_run=continue_prev_run, - 
use_augs=True, - noise_fac=0.1, - use_noise=None, - mse_withzeros=True, mse_weight=mse_weight, mse_weight_decay=mse_weight_decay, mse_weight_decay_steps=mse_weight_decay_steps, From 5c90660f3ca948445fde29cba412ca445558c8cc Mon Sep 17 00:00:00 2001 From: Tan Nian Wei Date: Thu, 14 Oct 2021 01:19:13 +0800 Subject: [PATCH 9/9] refactor: merged all MSE code to existing modules --- app.py | 52 ++++++ mseapp.py | 460 ------------------------------------------------------ 2 files changed, 52 insertions(+), 460 deletions(-) delete mode 100644 mseapp.py diff --git a/app.py b/app.py index 33b36a3..790442c 100644 --- a/app.py +++ b/app.py @@ -34,6 +34,9 @@ def generate_image( image_prompts: List[Image.Image] = [], continue_prev_run: bool = False, seed: Optional[int] = None, + mse_weight: float = 0, + mse_weight_decay: float = 0, + mse_weight_decay_steps: int = 0, ) -> None: ### Init ------------------------------------------------------------------- @@ -47,6 +50,9 @@ def generate_image( init_image=init_image, image_prompts=image_prompts, continue_prev_run=continue_prev_run, + mse_weight=mse_weight, + mse_weight_decay=mse_weight_decay, + mse_weight_decay_steps=mse_weight_decay_steps, ) ### Load model ------------------------------------------------------------- @@ -171,6 +177,9 @@ def generate_image( "vqgan_ckpt": vqgan_ckpt, "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"), "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"), + "mse_weight": mse_weight, + "mse_weight_decay": mse_weight_decay, + "mse_weight_decay_steps": mse_weight_decay_steps, }, f, indent=4, @@ -228,6 +237,9 @@ def generate_image( "vqgan_ckpt": vqgan_ckpt, "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"), "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"), + "mse_weight": mse_weight, + "mse_weight_decay": mse_weight_decay, + "mse_weight_decay_steps": mse_weight_decay_steps, }, f, indent=4, @@ -356,6 +368,43 @@ def generate_image( value=defaults["continue_prev_run"], help="Use existing image and existing weights for the next run. If yes, ignores 'Use starting image'", ) + + use_mse_reg = st.sidebar.checkbox( + "Use MSE regularization", + value=defaults["use_mse_regularization"], + help="Check to add MSE regularization", + ) + mse_weight_widget = st.sidebar.empty() + mse_weight_decay_widget = st.sidebar.empty() + mse_weight_decay_steps = st.sidebar.empty() + + if use_mse_reg is True: + mse_weight = mse_weight_widget.number_input( + "MSE weight", + value=defaults["mse_weight"], + # min_value=0.0, # leave this out to allow creativity + step=0.05, + help="Set weights for MSE regularization", + ) + mse_weight_decay = mse_weight_decay_widget.number_input( + "Decay MSE weight by ...", + value=0.0, + # min_value=0.0, # leave this out to allow creativity + step=0.05, + help="Subtracts MSE weight by this amount at every step change. MSE weight change stops at zero", + ) + mse_weight_decay_steps = mse_weight_decay_steps.number_input( + "... every N steps", + value=0, + min_value=0, + step=1, + help="Number of steps to subtract MSE weight. 
Leave zero for no weight decay", + ) + else: + mse_weight = 0 + mse_weight_decay = 0 + mse_weight_decay_steps = 0 + submitted = st.form_submit_button("Run!") # End of form @@ -408,6 +457,9 @@ def generate_image( init_image=init_image, image_prompts=image_prompts, continue_prev_run=continue_prev_run, + mse_weight=mse_weight, + mse_weight_decay=mse_weight_decay, + mse_weight_decay_steps=mse_weight_decay_steps, ) vid_display_slot.video("temp.mp4") # debug_slot.write(st.session_state) # DEBUG diff --git a/mseapp.py b/mseapp.py deleted file mode 100644 index ddd8f6d..0000000 --- a/mseapp.py +++ /dev/null @@ -1,460 +0,0 @@ -""" -This script is organized like so: -+ `if __name__ == "__main__" sets up the Streamlit UI elements -+ `generate_image` houses interactions between UI and the CLIP image -generation models -+ Core model code is abstracted in `logic.py` and imported in `generate_image` -""" -import streamlit as st -from pathlib import Path -import sys -import datetime -import shutil -import json -import os -import base64 - -sys.path.append("./taming-transformers") - -from PIL import Image -from typing import Optional, List -from omegaconf import OmegaConf -import imageio -import numpy as np -from logic import VQGANCLIPRun - - -def generate_image( - text_input: str = "the first day of the waters", - vqgan_ckpt: str = "vqgan_imagenet_f16_16384", - num_steps: int = 300, - image_x: int = 300, - image_y: int = 300, - init_image: Optional[Image.Image] = None, - image_prompts: List[Image.Image] = [], - continue_prev_run: bool = False, - seed: Optional[int] = None, - mse_weight: float = 0, - mse_weight_decay: float = 0, - mse_weight_decay_steps: int = 0, -) -> None: - - ### Init ------------------------------------------------------------------- - run = VQGANCLIPRun( - text_input=text_input, - vqgan_ckpt=vqgan_ckpt, - num_steps=num_steps, - image_x=image_x, - image_y=image_y, - seed=seed, - init_image=init_image, - image_prompts=image_prompts, - continue_prev_run=continue_prev_run, - mse_weight=mse_weight, - mse_weight_decay=mse_weight_decay, - mse_weight_decay_steps=mse_weight_decay_steps, - ) - - ### Load model ------------------------------------------------------------- - - if continue_prev_run is True: - run.load_model( - prev_model=st.session_state["model"], - prev_perceptor=st.session_state["perceptor"], - ) - prev_run_id = st.session_state["run_id"] - - else: - # Remove the cache first! 
CUDA out of memory - if "model" in st.session_state: - del st.session_state["model"] - - if "perceptor" in st.session_state: - del st.session_state["perceptor"] - - st.session_state["model"], st.session_state["perceptor"] = run.load_model() - prev_run_id = None - - # Generate random run ID - # Used to link runs linked w/ continue_prev_run - # ref: https://stackoverflow.com/a/42703382/13095028 - # Use URL and filesystem safe version since we're using this as a folder name - run_id = st.session_state["run_id"] = base64.urlsafe_b64encode( - os.urandom(6) - ).decode("ascii") - - run_start_dt = datetime.datetime.now() - - ### Model init ------------------------------------------------------------- - if continue_prev_run is True: - run.model_init(init_image=st.session_state["prev_im"]) - elif init_image is not None: - run.model_init(init_image=init_image) - else: - run.model_init() - - ### Iterate ---------------------------------------------------------------- - step_counter = 0 - frames = [] - - try: - # Try block catches st.script_runner.StopExecution, no need of a dedicated stop button - # Reason is st.form is meant to be self-contained either within sidebar, or in main body - # The way the form is implemented in this app splits the form across both regions - # This is intended to prevent the model settings from crowding the main body - # However, touching any button resets the app state, making it impossible to - # implement a stop button that can still dump output - # Thankfully there's a built-in stop button :) - while True: - # While loop to accomodate running predetermined steps or running indefinitely - status_text.text(f"Running step {step_counter}") - - _, im = run.iterate() - - if num_steps > 0: # skip when num_steps = -1 - step_progress_bar.progress((step_counter + 1) / num_steps) - else: - step_progress_bar.progress(100) - - # At every step, display and save image - im_display_slot.image(im, caption="Output image", output_format="PNG") - st.session_state["prev_im"] = im - - # ref: https://stackoverflow.com/a/33117447/13095028 - # im_byte_arr = io.BytesIO() - # im.save(im_byte_arr, format="JPEG") - # frames.append(im_byte_arr.getvalue()) # read() - frames.append(np.asarray(im)) - - step_counter += 1 - - if (step_counter == num_steps) and num_steps > 0: - break - - # Stitch into video using imageio - writer = imageio.get_writer("temp.mp4", fps=24) - for frame in frames: - writer.append_data(frame) - writer.close() - - # Save to output folder if run completed - runoutputdir = outputdir / ( - run_start_dt.strftime("%Y%m%dT%H%M%S") + "-" + run_id - ) - runoutputdir.mkdir() - - # Save final image - im.save(runoutputdir / "output.PNG", format="PNG") - - # Save init image - if init_image is not None: - init_image.save(runoutputdir / "init-image.JPEG", format="JPEG") - - # Save image prompts - for count, image_prompt in enumerate(image_prompts): - image_prompt.save( - runoutputdir / f"image-prompt-{count}.JPEG", format="JPEG" - ) - - # Save animation - shutil.copy("temp.mp4", runoutputdir / "anim.mp4") - - # Save metadata - with open(runoutputdir / "details.json", "w") as f: - json.dump( - { - "run_id": run_id, - "num_steps": step_counter, - "planned_num_steps": num_steps, - "text_input": text_input, - "init_image": False if init_image is None else True, - "image_prompts": False if len(image_prompts) == 0 else True, - "continue_prev_run": continue_prev_run, - "prev_run_id": prev_run_id, - "seed": run.seed, - "Xdim": image_x, - "ydim": image_y, - "vqgan_ckpt": vqgan_ckpt, - "start_time": 
run_start_dt.strftime("%Y%m%dT%H%M%S"), - "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"), - "mse_weight": mse_weight, - "mse_weight_decay": mse_weight_decay, - "mse_weight_decay_steps": mse_weight_decay_steps, - }, - f, - indent=4, - ) - - status_text.text("Done!") # End of run - - except st.script_runner.StopException as e: - # Dump output to dashboard - print(f"Received Streamlit StopException") - status_text.text("Execution interruped, dumping outputs ...") - writer = imageio.get_writer("temp.mp4", fps=24) - for frame in frames: - writer.append_data(frame) - writer.close() - - # TODO: Make the following DRY - # Save to output folder if run completed - runoutputdir = outputdir / ( - run_start_dt.strftime("%Y%m%dT%H%M%S") + "-" + run_id - ) - runoutputdir.mkdir() - - # Save final image - im.save(runoutputdir / "output.PNG", format="PNG") - - # Save init image - if init_image is not None: - init_image.save(runoutputdir / "init-image.JPEG", format="JPEG") - - # Save image prompts - for count, image_prompt in enumerate(image_prompts): - image_prompt.save( - runoutputdir / f"image-prompt-{count}.JPEG", format="JPEG" - ) - - # Save animation - shutil.copy("temp.mp4", runoutputdir / "anim.mp4") - - # Save metadata - with open(runoutputdir / "details.json", "w") as f: - json.dump( - { - "run_id": run_id, - "num_steps": step_counter, - "planned_num_steps": num_steps, - "text_input": text_input, - "init_image": False if init_image is None else True, - "image_prompts": False if len(image_prompts) == 0 else True, - "continue_prev_run": continue_prev_run, - "prev_run_id": prev_run_id, - "seed": run.seed, - "Xdim": image_x, - "ydim": image_y, - "vqgan_ckpt": vqgan_ckpt, - "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"), - "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"), - "mse_weight": mse_weight, - "mse_weight_decay": mse_weight_decay, - "mse_weight_decay_steps": mse_weight_decay_steps, - }, - f, - indent=4, - ) - status_text.text("Done!") # End of run - - -if __name__ == "__main__": - defaults = OmegaConf.load("defaults.yaml") - outputdir = Path("output") - if not outputdir.exists(): - outputdir.mkdir() - - st.set_page_config(page_title="VQGAN-CLIP playground") - st.title("VQGAN-CLIP playground") - - # Determine what weights are available in `assets/` - weights_dir = Path("assets").resolve() - available_weight_ckpts = list(weights_dir.glob("*.ckpt")) - available_weight_configs = list(weights_dir.glob("*.yaml")) - available_weights = [ - i.stem - for i in available_weight_ckpts - if i.stem in [j.stem for j in available_weight_configs] - ] - - # Set vqgan_imagenet_f16_1024 as default if possible - if "vqgan_imagenet_f16_1024" in available_weights: - default_weight_index = available_weights.index("vqgan_imagenet_f16_1024") - else: - default_weight_index = 0 - - # Start of input form - with st.form("form-inputs"): - # Only element not in the sidebar, but in the form - text_input = st.text_input( - "Text prompt", - help="VQGAN-CLIP will generate an image that best fits the prompt", - ) - radio = st.sidebar.radio( - "Model weights", - available_weights, - index=default_weight_index, - help="Choose which weights to load, trained on different datasets. Make sure the weights and configs are downloaded to `assets/` as per the README!", - ) - num_steps = st.sidebar.number_input( - "Num steps", - value=defaults["num_steps"], - min_value=-1, - max_value=None, - step=1, - help="Specify -1 to run indefinitely. 
Use Streamlit's stop button in the top right corner to terminate execution. The exception is caught so the most recent output will be dumped to dashboard", - ) - - image_x = st.sidebar.number_input( - "Xdim", value=defaults["Xdim"], help="Width of output image, in pixels" - ) - image_y = st.sidebar.number_input( - "ydim", value=defaults["ydim"], help="Height of output image, in pixels" - ) - set_seed = st.sidebar.checkbox( - "Set seed", - value=defaults["set_seed"], - help="Check to set random seed for reproducibility. Will add option to specify seed", - ) - - seed_widget = st.sidebar.empty() - if set_seed is True: - # Use text_input as number_input relies on JS - # which can't natively handle large numbers - # torch.seed() generates int w/ 19 or 20 chars! - seed_str = seed_widget.text_input( - "Seed", value=str(defaults["seed"]), help="Random seed to use" - ) - try: - seed = int(seed_str) - except ValueError as e: - st.error("seed input needs to be int") - else: - seed = None - - use_custom_starting_image = st.sidebar.checkbox( - "Use starting image", - value=defaults["use_starting_image"], - help="Check to add a starting image to the network", - ) - - starting_image_widget = st.sidebar.empty() - if use_custom_starting_image is True: - init_image = starting_image_widget.file_uploader( - "Upload starting image", - type=["png", "jpeg", "jpg"], - accept_multiple_files=False, - help="Starting image for the network, will be resized to fit specified dimensions", - ) - # Convert from UploadedFile object to PIL Image - if init_image is not None: - init_image: Image.Image = Image.open(init_image).convert( - "RGB" - ) # just to be sure - else: - init_image = None - - use_image_prompts = st.sidebar.checkbox( - "Add image prompt(s)", - value=defaults["use_image_prompts"], - help="Check to add image prompt(s), conditions the network similar to the text prompt", - ) - - image_prompts_widget = st.sidebar.empty() - if use_image_prompts is True: - image_prompts = image_prompts_widget.file_uploader( - "Upload image prompts(s)", - type=["png", "jpeg", "jpg"], - accept_multiple_files=True, - help="Image prompt(s) for the network, will be resized to fit specified dimensions", - ) - # Convert from UploadedFile object to PIL Image - if len(image_prompts) != 0: - image_prompts = [Image.open(i).convert("RGB") for i in image_prompts] - else: - image_prompts = [] - - continue_prev_run = st.sidebar.checkbox( - "Continue previous run", - value=defaults["continue_prev_run"], - help="Use existing image and existing weights for the next run. If yes, ignores 'Use starting image'", - ) - - use_mse_reg = st.sidebar.checkbox( - "Use MSE regularization", - value=defaults["use_mse_regularization"], - help="Check to add MSE regularization", - ) - mse_weight_widget = st.sidebar.empty() - mse_weight_decay_widget = st.sidebar.empty() - mse_weight_decay_steps = st.sidebar.empty() - - if use_mse_reg is True: - mse_weight = mse_weight_widget.number_input( - "MSE weight", - value=defaults["mse_weight"], - # min_value=0.0, # leave this out to allow creativity - step=0.05, - help="Set weights for MSE regularization", - ) - mse_weight_decay = mse_weight_decay_widget.number_input( - "Decay MSE weight by ...", - value=0.0, - # min_value=0.0, # leave this out to allow creativity - step=0.05, - help="Subtracts MSE weight by this amount at every step change. MSE weight change stops at zero", - ) - mse_weight_decay_steps = mse_weight_decay_steps.number_input( - "... 
every N steps", - value=0, - min_value=0, - step=1, - help="Number of steps to subtract MSE weight. Leave zero for no weight decay", - ) - submitted = st.form_submit_button("Run!") - # End of form - - status_text = st.empty() - status_text.text("Pending input prompt") - step_progress_bar = st.progress(0) - - im_display_slot = st.empty() - vid_display_slot = st.empty() - debug_slot = st.empty() - - if "prev_im" in st.session_state: - im_display_slot.image( - st.session_state["prev_im"], caption="Output image", output_format="PNG" - ) - - with st.beta_expander("Expand for README"): - with open("README.md", "r") as f: - # description = f.read() - # Preprocess links to redirect to github - # Thank you https://discuss.streamlit.io/u/asehmi, works like a charm! - # ref: https://discuss.streamlit.io/t/image-in-markdown/13274/8 - readme_lines = f.readlines() - readme_buffer = [] - images = ["docs/ui.jpeg", "docs/four-seasons-20210808.png"] - for line in readme_lines: - readme_buffer.append(line) - for image in images: - if image in line: - st.markdown(" ".join(readme_buffer[:-1])) - st.image( - f"https://raw.githubusercontent.com/tnwei/vqgan-clip-app/main/{image}" - ) - readme_buffer.clear() - st.markdown(" ".join(readme_buffer)) - - # st.write(description) - - if submitted: - # debug_slot.write(st.session_state) # DEBUG - status_text.text("Loading weights ...") - generate_image( - # Inputs - text_input=text_input, - vqgan_ckpt=radio, - num_steps=num_steps, - image_x=int(image_x), - image_y=int(image_y), - seed=int(seed) if set_seed is True else None, - init_image=init_image, - image_prompts=image_prompts, - continue_prev_run=continue_prev_run, - mse_weight=mse_weight, - mse_weight_decay=mse_weight_decay, - mse_weight_decay_steps=mse_weight_decay_steps, - ) - vid_display_slot.video("temp.mp4") - # debug_slot.write(st.session_state) # DEBUG