Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Cutn to the web interface sidebar. #20

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions .github/ISSUE_TEMPLATE/bug_report.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Desktop (please complete the following information):**
- OS: [e.g. Linux, Windows 10]
- Browser [e.g. Chrome, Safari, Brave]
- Version [e.g. 22]

**Additional context**
Add any other context about the problem here.
20 changes: 20 additions & 0 deletions .github/ISSUE_TEMPLATE/feature_request.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
# VQGAN weights
assets/*.ckpt
assets/*.yaml
assets/*

# Outputs
output*

# samples
samples/*

# Test data
test-samples/
2 changes: 1 addition & 1 deletion .streamlit/config.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[server]
# Default is 200 MB
maxUploadSize = 10
maxUploadSize = 5000
279 changes: 239 additions & 40 deletions app.py

Large diffs are not rendered by default.

20 changes: 15 additions & 5 deletions defaults.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
# Modify for different systems, e.g. larger default xdim/ydim for more powerful GPUs
num_steps: 500
Xdim: 640
ydim: 480
use_clip_model: false
clip_model: ViT-B/32
num_steps: -1
Xdim: 662
ydim: 360
set_seed: false
seed: 0
use_cutn: false
cutn: 32
cut_pow: 1.0
custom_step_size: false
step_size: 0.05
use_custom_opt: false
opt_name: Adam
use_starting_image: false
use_image_prompts: false
continue_prev_run: false
mse_weight: 0.5
mse_weight_decay: 0.1
mse_weight_decay_steps: 50
use_mse_regularization: false
use_tv_loss_regularization: true
tv_loss_weight: 1e-3
use_tv_loss_regularization: false
# best values for tv_loss_weight are 0.000085, 0.0001 or 0.0002
tv_loss_weight: 0.000085
4 changes: 2 additions & 2 deletions diffusion_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def generate_image(
frames = []

try:
# Try block catches st.script_runner.StopExecution, no need of a dedicated stop button
# Try block catches st.StopExecution, no need of a dedicated stop button
# Reason is st.form is meant to be self-contained either within sidebar, or in main body
# The way the form is implemented in this app splits the form across both regions
# This is intended to prevent the model settings from crowding the main body
Expand Down Expand Up @@ -181,7 +181,7 @@ def generate_image(

status_text.text("Done!") # End of run

except st.script_runner.StopException as e:
except st.StopException as e:
# Dump output to dashboard
print(f"Received Streamlit StopException")
status_text.text("Execution interruped, dumping outputs ...")
Expand Down
10 changes: 6 additions & 4 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ channels:
- conda-forge
- defaults
dependencies:
- pytorch::pytorch=1.10.0
- pytorch::torchvision=0.11.1
- cudatoolkit=10.2
- pytorch::pytorch=1.12.0
- pytorch::torchvision=0.13.1
- cudatoolkit=10.2 # The cudatoolkit library could also be updated to 11.3 but might give some troubles with older GPUs, for an RTX 3050 or higher cudatoolkit=11.3 is recommended.
- omegaconf
- pytorch-lightning
- pytorch-lightning=1.5.8 # For compatibility
- tqdm
- regex
- kornia
Expand All @@ -32,6 +32,8 @@ dependencies:
# - imgtag
- einops
- transformers
- torch-optimizer
- retry
- git+https://github.com/openai/CLIP
# For guided diffusion
- lpips
Expand Down
108 changes: 94 additions & 14 deletions logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,57 @@
import cv2
import numpy as np
import kornia.augmentation as K
from torch_optimizer import DiffGrad, AdamP, RAdam


# Set the optimiser
def get_opt(opt_name, z, opt_lr):
"""
List of optimizers
Adadelta: Implements Adadelta algorithm.
Adagrad: Implements Adagrad algorithm.
Adam: Implements Adam algorithm.
AdamW: Implements AdamW algorithm.
Adamax: Implements Adamax algorithm (a variant of Adam based on infinity norm).
ASGD: Implements Averaged Stochastic Gradient Descent.
NAdam: Implements NAdam algorithm.
RAdam: Implements RAdam algorithm.
RMSprop: Implements RMSprop algorithm.
Rprop: Implements the resilient backpropagation algorithm.
SGD: Implements stochastic gradient descent (optionally with momentum)."""

if opt_name == "Adam":
opt = optim.Adam([z], lr=opt_lr)
elif opt_name == "AdamW":
opt = optim.AdamW([z], lr=opt_lr)
elif opt_name == "Adagrad":
opt = optim.Adagrad([z], lr=opt_lr)
elif opt_name == "Adamax":
opt = optim.Adamax([z], lr=opt_lr)
elif opt_name == "AdamP":
opt = AdamP([z], lr=opt_lr)
elif opt_name == "Adadelta":
opt = optim.Adadelta([z], lr=opt_lr, eps=1e-9, weight_decay=1e-9)
elif opt_name == "ASGD":
opt = optim.ASGD([z], lr=opt_lr)
elif opt_name == "DiffGrad":
opt = DiffGrad([z], lr=opt_lr, eps=1e-9, weight_decay=1e-9)
elif opt_name == "NAdam":
opt = optim.NAdam([z], lr=opt_lr)
elif opt_name == "RAdam":
opt = RAdam([z], lr=opt_lr)
elif opt_name == "RMSprop":
opt = optim.RMSprop([z], lr=opt_lr)
elif opt_name == "Rprop":
opt = optim.Rprop([z], lr=opt_lr)
elif opt_name == "SGD":
opt = optim.SGD([z], lr=opt_lr)

else:
print(f"Unknown optimiser: {opt_name} | Are choices broken?")
opt = optim.Adam([z], lr=opt_lr)
return opt

class Run:
"""
Subclass this to house your own implementation of CLIP-based image generation
Expand Down Expand Up @@ -63,17 +112,22 @@ def __init__(
self,
text_input: str = "the first day of the waters",
vqgan_ckpt: str = "vqgan_imagenet_f16_16384",
clip_model: str = "ViT-B/32",
num_steps: int = 300,
image_x: int = 300,
image_y: int = 300,
init_image: Optional[Image.Image] = None,
image_prompts: List[Image.Image] = [],
continue_prev_run: bool = False,
seed: Optional[int] = None,
cutn: int = 32,
cut_pow: float = 1.0,
step_size: float = 0.05,
opt_name: str = "Adam",
mse_weight=0.5,
mse_weight_decay=0.1,
mse_weight_decay_steps=50,
tv_loss_weight=1e-3,
tv_loss_weight=0.000085,
use_cutout_augmentations: bool = True,
# use_augs: bool = True,
# noise_fac: float = 0.1,
Expand All @@ -86,18 +140,24 @@ def __init__(
rotation_angle: float = 0,
zoom_factor: float = 1,
transform_interval: int = 10,
device: Optional[torch.device] = None,
device: Optional[torch.device] = "cpu",
) -> None:
super().__init__()
self.text_input = text_input
self.vqgan_ckpt = vqgan_ckpt
self.clip_model = clip_model
self.num_steps = num_steps
self.image_x = image_x
self.image_y = image_y
self.init_image = init_image
self.image_prompts = image_prompts
self.continue_prev_run = continue_prev_run
self.seed = seed
self.cutn = cutn
self.cut_pow = cut_pow
self.step_size = step_size
self.opt_name = opt_name
self.device = device

# Setup ------------------------------------------------------------------------------
# Split text by "|" symbol
Expand All @@ -115,22 +175,26 @@ def __init__(
init_image=init_image,
init_weight=mse_weight,
# clip.available_models()
# ['RN50', 'RN101', 'RN50x4', 'ViT-B/32']
# ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
# Visual Transformer seems to be the smallest
clip_model="ViT-B/32",
clip_model=clip_model,
vqgan_config=f"assets/{vqgan_ckpt}.yaml",
vqgan_checkpoint=f"assets/{vqgan_ckpt}.ckpt",
step_size=0.05,
cutn=64,
cut_pow=1.0,
cutn=cutn,
cut_pow=cut_pow,
step_size=step_size,
opt_name=opt_name,
display_freq=50,
seed=seed,
device=device,
)

if device is None:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device is None or device == "cpu":
#self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
else:
self.device = device
#self.device = device
self.device = torch.device(f"{device}" if torch.cuda.is_available() else "cpu")

print("Using device:", device)

Expand All @@ -155,6 +219,8 @@ def __init__(
self.rotation_angle = rotation_angle
self.zoom_factor = zoom_factor
self.transform_interval = transform_interval



def load_model(
self, prev_model: nn.Module = None, prev_perceptor: nn.Module = None
Expand Down Expand Up @@ -214,6 +280,7 @@ def model_init(self, init_image: Image.Image = None) -> None:
None, :, None, None
]


if self.seed is not None:
torch.manual_seed(self.seed)
else:
Expand All @@ -239,7 +306,8 @@ def model_init(self, init_image: Image.Image = None) -> None:
self.z = self.z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
self.z_orig = self.z.clone()
self.z.requires_grad_(True)
self.opt = optim.Adam([self.z], lr=self.args.step_size)
#self.opt = optim.Adam([self.z], lr=self.args.step_size)
self.opt = get_opt(self.opt_name, self.z, self.args.step_size)

self.normalize = transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
Expand Down Expand Up @@ -315,11 +383,14 @@ def _ascend_txt(self) -> List:
result[f"prompt_loss_{count}"] = prompt(iii)

return result

def iterate(self) -> Tuple[List[float], Image.Image]:
if not self.use_scrolling_zooming:
# Forward prop
self.opt.zero_grad()
#for param in self.model.parameters():
#param.grad = None

losses = self._ascend_txt()

# Grab an image
Expand Down Expand Up @@ -377,11 +448,16 @@ def iterate(self) -> Tuple[List[float], Image.Image]:
TF.to_tensor(transformed_im).to(self.device).unsqueeze(0) * 2 - 1
)
self.z.requires_grad_(True)
self.opt = optim.Adam([self.z], lr=self.args.step_size)

#self.opt = optim.Adam([self.z], lr=self.args.step_size)
self.opt = get_opt(self.opt_name, self.z, self.args.step_size)

for _ in range(self.transform_interval):
# Forward prop
self.opt.zero_grad()
#for param in self.model.parameters():
#param.grad = None

losses = self._ascend_txt()

# Grab an image
Expand All @@ -393,9 +469,13 @@ def iterate(self) -> Tuple[List[float], Image.Image]:
self.opt.step()
with torch.no_grad():
self.z.copy_(self.z.maximum(self.z_min).minimum(self.z_max))

# Advance iteration counter
self.iterate_counter += 1

for param_group in self.opt.param_groups:
#print (param_group)
print (f"Learning Rate: {param_group['lr']}")

print(
f"Step {self.iterate_counter} losses: {[(i, j.item()) for i, j in losses.items()]}"
Expand Down