From ae32237d7edc85ec0630e43052cd95a37066d00d Mon Sep 17 00:00:00 2001 From: tnwei <12769364+tnwei@users.noreply.github.com> Date: Thu, 14 Oct 2021 01:20:59 +0800 Subject: [PATCH] feat: add mse regularization (PR #4) * feat: minimal working example * feat: added widget for mse init weight * feat: added MSE schedule * feat: nested mse option, added to metadata output * refactor: trimmed unused code stubs * refactor: separated out non MSE features * refactor: rm'ed temp modules * refactor: merged all non-UI code * refactor: merged all MSE code to existing modules --- app.py | 52 +++++++++++++++++++++++++++++++++++++++++++++++++++ defaults.yaml | 2 ++ logic.py | 49 ++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 101 insertions(+), 2 deletions(-) diff --git a/app.py b/app.py index 33b36a3..790442c 100644 --- a/app.py +++ b/app.py @@ -34,6 +34,9 @@ def generate_image( image_prompts: List[Image.Image] = [], continue_prev_run: bool = False, seed: Optional[int] = None, + mse_weight: float = 0, + mse_weight_decay: float = 0, + mse_weight_decay_steps: int = 0, ) -> None: ### Init ------------------------------------------------------------------- @@ -47,6 +50,9 @@ def generate_image( init_image=init_image, image_prompts=image_prompts, continue_prev_run=continue_prev_run, + mse_weight=mse_weight, + mse_weight_decay=mse_weight_decay, + mse_weight_decay_steps=mse_weight_decay_steps, ) ### Load model ------------------------------------------------------------- @@ -171,6 +177,9 @@ def generate_image( "vqgan_ckpt": vqgan_ckpt, "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"), "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"), + "mse_weight": mse_weight, + "mse_weight_decay": mse_weight_decay, + "mse_weight_decay_steps": mse_weight_decay_steps, }, f, indent=4, @@ -228,6 +237,9 @@ def generate_image( "vqgan_ckpt": vqgan_ckpt, "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"), "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"), + "mse_weight": mse_weight, + "mse_weight_decay": mse_weight_decay, + "mse_weight_decay_steps": mse_weight_decay_steps, }, f, indent=4, @@ -356,6 +368,43 @@ def generate_image( value=defaults["continue_prev_run"], help="Use existing image and existing weights for the next run. If yes, ignores 'Use starting image'", ) + + use_mse_reg = st.sidebar.checkbox( + "Use MSE regularization", + value=defaults["use_mse_regularization"], + help="Check to add MSE regularization", + ) + mse_weight_widget = st.sidebar.empty() + mse_weight_decay_widget = st.sidebar.empty() + mse_weight_decay_steps = st.sidebar.empty() + + if use_mse_reg is True: + mse_weight = mse_weight_widget.number_input( + "MSE weight", + value=defaults["mse_weight"], + # min_value=0.0, # leave this out to allow creativity + step=0.05, + help="Set weights for MSE regularization", + ) + mse_weight_decay = mse_weight_decay_widget.number_input( + "Decay MSE weight by ...", + value=0.0, + # min_value=0.0, # leave this out to allow creativity + step=0.05, + help="Subtracts MSE weight by this amount at every step change. MSE weight change stops at zero", + ) + mse_weight_decay_steps = mse_weight_decay_steps.number_input( + "... every N steps", + value=0, + min_value=0, + step=1, + help="Number of steps to subtract MSE weight. Leave zero for no weight decay", + ) + else: + mse_weight = 0 + mse_weight_decay = 0 + mse_weight_decay_steps = 0 + submitted = st.form_submit_button("Run!") # End of form @@ -408,6 +457,9 @@ def generate_image( init_image=init_image, image_prompts=image_prompts, continue_prev_run=continue_prev_run, + mse_weight=mse_weight, + mse_weight_decay=mse_weight_decay, + mse_weight_decay_steps=mse_weight_decay_steps, ) vid_display_slot.video("temp.mp4") # debug_slot.write(st.session_state) # DEBUG diff --git a/defaults.yaml b/defaults.yaml index 1ce66b6..16c4902 100644 --- a/defaults.yaml +++ b/defaults.yaml @@ -7,3 +7,5 @@ seed: 0 use_starting_image: false use_image_prompts: false continue_prev_run: false +mse_weight: 0.0 +use_mse_regularization: false \ No newline at end of file diff --git a/logic.py b/logic.py index 8a64fbb..741ea5d 100644 --- a/logic.py +++ b/logic.py @@ -66,6 +66,13 @@ def __init__( image_prompts: List[Image.Image] = [], continue_prev_run: bool = False, seed: Optional[int] = None, + mse_weight=0.5, + mse_weight_decay=0.1, + mse_weight_decay_steps=50, + # use_augs: bool = True, + # noise_fac: float = 0.1, + # use_noise: Optional[float] = None, + # mse_withzeros=True, ## **kwargs, # Use this to receive Streamlit objects ## Call from main UI ) -> None: super().__init__() @@ -93,7 +100,7 @@ def __init__( noise_prompt_weights=[], size=[int(image_x), int(image_y)], init_image=init_image, - init_weight=0.0, + init_weight=mse_weight, # clip.available_models() # ['RN50', 'RN101', 'RN50x4', 'ViT-B/32'] # Visual Transformer seems to be the smallest @@ -111,6 +118,16 @@ def __init__( self.device = device print("Using device:", device) + self.iterate_counter = 0 + # self.use_augs = use_augs + # self.noise_fac = noise_fac + # self.use_noise = use_noise + # self.mse_withzeros = mse_withzeros + self.init_mse_weight = mse_weight + self.mse_weight = mse_weight + self.mse_weight_decay = mse_weight_decay + self.mse_weight_decay_steps = mse_weight_decay_steps + def load_model( self, prev_model: nn.Module = None, prev_perceptor: nn.Module = None ) -> Optional[Tuple[nn.Module, nn.Module]]: @@ -217,7 +234,32 @@ def _ascend_txt(self) -> List: result = [] if self.args.init_weight: - result.append(F.mse_loss(self.z, self.z_orig) * self.args.init_weight / 2) + result.append(F.mse_loss(self.z, self.z_orig) * self.init_mse_weight / 2) + + # MSE regularization scheduler + with torch.no_grad(): + # if not the first step + # and is time for step change + # and both weight decay steps and magnitude are nonzero + # and MSE isn't zero already + if ( + self.iterate_counter > 0 + and self.iterate_counter % self.mse_weight_decay_steps == 0 + and self.mse_weight_decay != 0 + and self.mse_weight_decay_steps != 0 + and self.mse_weight != 0 + ): + self.mse_weight = self.mse_weight - self.mse_weight_decay + + # Don't allow changing sign + # Basically, caps MSE at zero if decreasing from positive + # But, also prevents MSE from becoming positive if -MSE intended + if self.init_mse_weight > 0: + self.mse_weight = max(self.mse_weight, 0) + else: + self.mse_weight = min(self.mse_weight, 0) + + print(f"updated mse weight: {self.mse_weight}") for prompt in self.pMs: result.append(prompt(iii)) @@ -239,5 +281,8 @@ def iterate(self) -> Tuple[List[float], Image.Image]: with torch.no_grad(): self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max)) + # Advance iteration counter + self.iterate_counter += 1 + # Output stuff useful for humans return [loss.item() for loss in losses], im