main_run.py
import argparse
from diffusers import StableDiffusionPipeline, DDIMScheduler
import os
from prompt_to_prompt.ptp_classes import AttentionStore, AttentionReplace, AttentionRefine, EmptyControl, load_512
from prompt_to_prompt.ptp_utils import register_attention_control, text2image_ldm_stable, view_images
from ddm_inversion.inversion_utils import inversion_forward_process, inversion_reverse_process
from ddm_inversion.utils import image_grid, dataset_from_yaml
from torch import autocast, inference_mode
from ddm_inversion.ddim_inversion import ddim_inversion
import calendar
import time
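
# Example invocation (the values shown are the argparse defaults defined below;
# the YAML layout is sketched after the dataset_from_yaml call):
#   python main_run.py --mode our_inv --dataset_yaml test.yaml \
#       --num_diffusion_steps 100 --skip 36 --cfg_src 3.5 --cfg_tar 15
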
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device_num", type=int, default=0)
    parser.add_argument("--cfg_src", type=float, default=3.5)
    parser.add_argument("--cfg_tar", type=float, default=15)
    parser.add_argument("--num_diffusion_steps", type=int, default=100)
    parser.add_argument("--dataset_yaml", default="test.yaml")
    parser.add_argument("--eta", type=float, default=1)
    parser.add_argument("--mode", default="our_inv", help="modes: our_inv,p2pinv,p2pddim,ddim")
    parser.add_argument("--skip", type=int, default=36)
    parser.add_argument("--xa", type=float, default=0.6)
    parser.add_argument("--sa", type=float, default=0.2)
    args = parser.parse_args()

    full_data = dataset_from_yaml(args.dataset_yaml)
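    # Illustrative YAML layout, inferred from the keys accessed in the loop below
    # (init_img, optional source_prompt, target_prompts); the path and prompts
    # here are placeholders:
    #   - init_img: /example_images/horse.png
    #     source_prompt: a photo of a horse in the mud
    #     target_prompts:
    #       - a photo of a horse in the snow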

    # load diffusion model (the scheduler is set per-mode inside the loop below)
    model_id = "CompVis/stable-diffusion-v1-4"
    # model_id = "stable_diff_local"  # load local copy of the model (for internet problems)
    device = f"cuda:{args.device_num}"

    cfg_scale_src = args.cfg_src
    cfg_scale_tar_list = [args.cfg_tar]
    eta = args.eta  # eta = 1: fully stochastic DDPM-style sampling (eta = 0 would be deterministic DDIM)
    skip_zs = [args.skip]
    xa_sa_string = f'_xa_{args.xa}_sa{args.sa}_' if args.mode == 'p2pinv' else '_'

    current_GMT = time.gmtime()
    time_stamp = calendar.timegm(current_GMT)

    # load/reload model:
    ldm_stable = StableDiffusionPipeline.from_pretrained(model_id).to(device)

    for i in range(len(full_data)):
        current_image_data = full_data[i]
        image_path = current_image_data['init_img']
        image_path = '.' + image_path
        image_folder = image_path.split('/')[1]  # after '.'
        prompt_src = current_image_data.get('source_prompt', "")  # default empty string
        prompt_tar_list = current_image_data['target_prompts']

        if args.mode == "p2pddim" or args.mode == "ddim":
            scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
            ldm_stable.scheduler = scheduler
        else:
            # load the scheduler config shipped with the model
            ldm_stable.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

        ldm_stable.scheduler.set_timesteps(args.num_diffusion_steps)

        # load image
        offsets = (0, 0, 0, 0)
        x0 = load_512(image_path, *offsets, device)

        # vae encode image
        with autocast("cuda"), inference_mode():
            w0 = (ldm_stable.vae.encode(x0).latent_dist.mode() * 0.18215).float()
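        # 0.18215 is the Stable Diffusion v1 latent scaling factor; the inverse
        # factor 1 / 0.18215 is applied again before VAE decoding below.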

        # find Zs and wts - forward process
        if args.mode == "p2pddim" or args.mode == "ddim":
            wT = ddim_inversion(ldm_stable, w0, prompt_src, cfg_scale_src)
        else:
            wt, zs, wts = inversion_forward_process(ldm_stable, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=True, num_inference_steps=args.num_diffusion_steps)
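            # wts holds the latent trajectory of the forward process and zs the
            # per-step noise maps; re-injecting zs in the reverse process below is
            # what lets the sampler reproduce (and edit) the input image.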

        # iterate over decoder prompts
        for k in range(len(prompt_tar_list)):
            prompt_tar = prompt_tar_list[k]
            # sub-folder named after the input image (filename without extension)
            image_name = os.path.splitext(os.path.basename(image_path))[0]
            save_path = os.path.join('./results/', args.mode + xa_sa_string + str(time_stamp), image_name, 'src_' + prompt_src.replace(" ", "_"), 'dec_' + prompt_tar.replace(" ", "_"))
            os.makedirs(save_path, exist_ok=True)

            # check whether the source and target prompts have the same number of words
            src_tar_len_eq = (len(prompt_src.split(" ")) == len(prompt_tar.split(" ")))

            for cfg_scale_tar in cfg_scale_tar_list:
                for skip in skip_zs:
                    if args.mode == "our_inv":
                        # reverse process (via Zs and wT)
                        controller = AttentionStore()
                        register_attention_control(ldm_stable, controller)
                        w0, _ = inversion_reverse_process(ldm_stable, xT=wts[args.num_diffusion_steps-skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[:(args.num_diffusion_steps-skip)], controller=controller)
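                        # sampling starts from the intermediate latent wts[num_diffusion_steps - skip]
                        # rather than from pure noise; a larger skip preserves more of the
                        # source image's structure.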
                    elif args.mode == "p2pinv":
                        # inversion with attention replace
                        cfg_scale_list = [cfg_scale_src, cfg_scale_tar]
                        prompts = [prompt_src, prompt_tar]
                        if src_tar_len_eq:
                            controller = AttentionReplace(prompts, args.num_diffusion_steps, cross_replace_steps=args.xa, self_replace_steps=args.sa, model=ldm_stable)
                        else:
                            # Refine is needed when the target prompt has a different number of tokens
                            controller = AttentionRefine(prompts, args.num_diffusion_steps, cross_replace_steps=args.xa, self_replace_steps=args.sa, model=ldm_stable)
                        register_attention_control(ldm_stable, controller)
                        w0, _ = inversion_reverse_process(ldm_stable, xT=wts[args.num_diffusion_steps-skip], etas=eta, prompts=prompts, cfg_scales=cfg_scale_list, prog_bar=True, zs=zs[:(args.num_diffusion_steps-skip)], controller=controller)
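                        # the reverse process returns a [source, target] batch; keep the edited target image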
                        w0 = w0[1].unsqueeze(0)
                    elif args.mode == "p2pddim" or args.mode == "ddim":
                        # DDIM modes use no injected noise (z = 0) and run only with skip == 0
                        if skip != 0:
                            continue
                        prompts = [prompt_src, prompt_tar]
                        if args.mode == "p2pddim":
                            if src_tar_len_eq:
                                controller = AttentionReplace(prompts, args.num_diffusion_steps, cross_replace_steps=.8, self_replace_steps=0.4, model=ldm_stable)
                            else:
                                # Refine is needed when the target prompt has a different number of tokens
                                controller = AttentionRefine(prompts, args.num_diffusion_steps, cross_replace_steps=.8, self_replace_steps=0.4, model=ldm_stable)
                        else:
                            controller = EmptyControl()
                        register_attention_control(ldm_stable, controller)

                        # generate from the DDIM-inverted latent wT
                        cfg_scale_list = [cfg_scale_src, cfg_scale_tar]
                        w0, latent = text2image_ldm_stable(ldm_stable, prompts, controller, args.num_diffusion_steps, cfg_scale_list, None, wT)
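                        # keep only the target-prompt image from the [source, target] batch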
                        w0 = w0[1:2]
                    else:
                        raise NotImplementedError

                    # vae decode image
                    with autocast("cuda"), inference_mode():
                        x0_dec = ldm_stable.vae.decode(1 / 0.18215 * w0).sample
                    if x0_dec.dim() < 4:
                        x0_dec = x0_dec[None, :, :, :]
                    img = image_grid(x0_dec)

                    # save output
                    current_GMT = time.gmtime()
                    time_stamp_name = calendar.timegm(current_GMT)
                    image_name_png = f'cfg_d_{cfg_scale_tar}_skip_{skip}_{time_stamp_name}.png'
                    save_full_path = os.path.join(save_path, image_name_png)
                    img.save(save_full_path)