Skip to content

Commit

Permalink
feat: add mse regularization (PR #4)
Browse files Browse the repository at this point in the history
* feat: minimal working example
* feat: added widget for mse init weight
* feat: added MSE schedule
* feat: nested mse option, added to metadata output
* refactor: trimmed unused code stubs
* refactor: separated out non MSE features
* refactor: rm'ed temp modules
* refactor: merged all non-UI code
* refactor: merged all MSE code to existing modules
  • Loading branch information
tnwei authored Oct 13, 2021
1 parent f2854e4 commit ae32237
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 2 deletions.
52 changes: 52 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def generate_image(
image_prompts: List[Image.Image] = [],
continue_prev_run: bool = False,
seed: Optional[int] = None,
mse_weight: float = 0,
mse_weight_decay: float = 0,
mse_weight_decay_steps: int = 0,
) -> None:

### Init -------------------------------------------------------------------
Expand All @@ -47,6 +50,9 @@ def generate_image(
init_image=init_image,
image_prompts=image_prompts,
continue_prev_run=continue_prev_run,
mse_weight=mse_weight,
mse_weight_decay=mse_weight_decay,
mse_weight_decay_steps=mse_weight_decay_steps,
)

### Load model -------------------------------------------------------------
Expand Down Expand Up @@ -171,6 +177,9 @@ def generate_image(
"vqgan_ckpt": vqgan_ckpt,
"start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"),
"end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"),
"mse_weight": mse_weight,
"mse_weight_decay": mse_weight_decay,
"mse_weight_decay_steps": mse_weight_decay_steps,
},
f,
indent=4,
Expand Down Expand Up @@ -228,6 +237,9 @@ def generate_image(
"vqgan_ckpt": vqgan_ckpt,
"start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"),
"end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"),
"mse_weight": mse_weight,
"mse_weight_decay": mse_weight_decay,
"mse_weight_decay_steps": mse_weight_decay_steps,
},
f,
indent=4,
Expand Down Expand Up @@ -356,6 +368,43 @@ def generate_image(
value=defaults["continue_prev_run"],
help="Use existing image and existing weights for the next run. If yes, ignores 'Use starting image'",
)

use_mse_reg = st.sidebar.checkbox(
"Use MSE regularization",
value=defaults["use_mse_regularization"],
help="Check to add MSE regularization",
)
mse_weight_widget = st.sidebar.empty()
mse_weight_decay_widget = st.sidebar.empty()
mse_weight_decay_steps = st.sidebar.empty()

if use_mse_reg is True:
mse_weight = mse_weight_widget.number_input(
"MSE weight",
value=defaults["mse_weight"],
# min_value=0.0, # leave this out to allow creativity
step=0.05,
help="Set weights for MSE regularization",
)
mse_weight_decay = mse_weight_decay_widget.number_input(
"Decay MSE weight by ...",
value=0.0,
# min_value=0.0, # leave this out to allow creativity
step=0.05,
help="Subtracts MSE weight by this amount at every step change. MSE weight change stops at zero",
)
mse_weight_decay_steps = mse_weight_decay_steps.number_input(
"... every N steps",
value=0,
min_value=0,
step=1,
help="Number of steps to subtract MSE weight. Leave zero for no weight decay",
)
else:
mse_weight = 0
mse_weight_decay = 0
mse_weight_decay_steps = 0

submitted = st.form_submit_button("Run!")
# End of form

Expand Down Expand Up @@ -408,6 +457,9 @@ def generate_image(
init_image=init_image,
image_prompts=image_prompts,
continue_prev_run=continue_prev_run,
mse_weight=mse_weight,
mse_weight_decay=mse_weight_decay,
mse_weight_decay_steps=mse_weight_decay_steps,
)
vid_display_slot.video("temp.mp4")
# debug_slot.write(st.session_state) # DEBUG
2 changes: 2 additions & 0 deletions defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ seed: 0
use_starting_image: false
use_image_prompts: false
continue_prev_run: false
mse_weight: 0.0
use_mse_regularization: false
49 changes: 47 additions & 2 deletions logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ def __init__(
image_prompts: List[Image.Image] = [],
continue_prev_run: bool = False,
seed: Optional[int] = None,
mse_weight=0.5,
mse_weight_decay=0.1,
mse_weight_decay_steps=50,
# use_augs: bool = True,
# noise_fac: float = 0.1,
# use_noise: Optional[float] = None,
# mse_withzeros=True,
## **kwargs, # Use this to receive Streamlit objects ## Call from main UI
) -> None:
super().__init__()
Expand Down Expand Up @@ -93,7 +100,7 @@ def __init__(
noise_prompt_weights=[],
size=[int(image_x), int(image_y)],
init_image=init_image,
init_weight=0.0,
init_weight=mse_weight,
# clip.available_models()
# ['RN50', 'RN101', 'RN50x4', 'ViT-B/32']
# Visual Transformer seems to be the smallest
Expand All @@ -111,6 +118,16 @@ def __init__(
self.device = device
print("Using device:", device)

self.iterate_counter = 0
# self.use_augs = use_augs
# self.noise_fac = noise_fac
# self.use_noise = use_noise
# self.mse_withzeros = mse_withzeros
self.init_mse_weight = mse_weight
self.mse_weight = mse_weight
self.mse_weight_decay = mse_weight_decay
self.mse_weight_decay_steps = mse_weight_decay_steps

def load_model(
self, prev_model: nn.Module = None, prev_perceptor: nn.Module = None
) -> Optional[Tuple[nn.Module, nn.Module]]:
Expand Down Expand Up @@ -217,7 +234,32 @@ def _ascend_txt(self) -> List:
result = []

if self.args.init_weight:
result.append(F.mse_loss(self.z, self.z_orig) * self.args.init_weight / 2)
result.append(F.mse_loss(self.z, self.z_orig) * self.init_mse_weight / 2)

# MSE regularization scheduler
with torch.no_grad():
# if not the first step
# and is time for step change
# and both weight decay steps and magnitude are nonzero
# and MSE isn't zero already
if (
self.iterate_counter > 0
and self.iterate_counter % self.mse_weight_decay_steps == 0
and self.mse_weight_decay != 0
and self.mse_weight_decay_steps != 0
and self.mse_weight != 0
):
self.mse_weight = self.mse_weight - self.mse_weight_decay

# Don't allow changing sign
# Basically, caps MSE at zero if decreasing from positive
# But, also prevents MSE from becoming positive if -MSE intended
if self.init_mse_weight > 0:
self.mse_weight = max(self.mse_weight, 0)
else:
self.mse_weight = min(self.mse_weight, 0)

print(f"updated mse weight: {self.mse_weight}")

for prompt in self.pMs:
result.append(prompt(iii))
Expand All @@ -239,5 +281,8 @@ def iterate(self) -> Tuple[List[float], Image.Image]:
with torch.no_grad():
self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max))

# Advance iteration counter
self.iterate_counter += 1

# Output stuff useful for humans
return [loss.item() for loss in losses], im

0 comments on commit ae32237

Please sign in to comment.