# ddim_inversion.py (source listing originally 60 lines, 47 loc, 2.49 KB)
from typing import Union, Tuple, Optional
import matplotlib.pyplot as plt
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, DDIMInverseScheduler, AutoencoderKL, DDIMScheduler
from torchvision import transforms as tvt
def load_image(imgname: str, target_size: Optional[Union[int, Tuple[int, int]]] = None) -> torch.Tensor:
    """Read an image file and return it as a batched RGB tensor.

    Args:
        imgname: Path of the image file to open.
        target_size: Optional output size. A single ``int`` is expanded to
            the square ``(size, size)``; a tuple is handed to
            ``PIL.Image.resize`` unchanged. ``None`` keeps the resolution
            of the file.

    Returns:
        The result of ``torchvision.transforms.ToTensor`` applied to the
        (optionally resized) RGB image, with a leading batch axis of
        length one added.
    """
    rgb = Image.open(imgname).convert('RGB')

    if target_size is None:
        prepared = rgb
    else:
        # Normalise the scalar shorthand to an explicit 2-tuple.
        size = (target_size, target_size) if isinstance(target_size, int) else target_size
        prepared = rgb.resize(size, Image.Resampling.LANCZOS)

    to_tensor = tvt.ToTensor()
    # unsqueeze(0) adds the batch dimension in front of the existing axes.
    return to_tensor(prepared).unsqueeze(0)
def img_to_latents(x: torch.Tensor, vae: AutoencoderKL):
    """Encode an image batch into scaled VAE latents.

    Args:
        x: Image tensor. It is remapped with ``2 * x - 1`` before encoding,
            so inputs in ``[0, 1]`` reach the VAE in ``[-1, 1]``.
        vae: Autoencoder whose ``encode(...)`` result exposes a
            ``latent_dist`` with a ``mean`` attribute.

    Returns:
        The posterior mean multiplied by the VAE's latent scaling factor.
    """
    # Prefer the factor stored in the VAE config so checkpoints that use a
    # different latent scale are handled; 0.18215 (the value this function
    # used unconditionally before) remains the fallback when the config
    # does not carry one.
    scaling_factor = getattr(getattr(vae, 'config', None), 'scaling_factor', 0.18215)

    centered = 2. * x - 1.
    posterior = vae.encode(centered).latent_dist
    # The mean (not a random draw) is used, so repeated calls on the same
    # input give the same latents.
    latents = posterior.mean * scaling_factor
    return latents
@torch.no_grad()
def ddim_inversion(imgname: str, num_steps: int = 50, verify: Optional[bool] = False) -> torch.Tensor:
    """Invert an image into Stable Diffusion 2.1 latents with the DDIM inverse scheduler.

    The image is encoded by the pipeline's VAE and the resulting latents are
    run through the pipeline with ``DDIMInverseScheduler`` and empty prompts.

    Args:
        imgname: Path of the image to invert. It is loaded at its native
            resolution, which is also passed to the pipeline as
            ``width``/``height``.
        num_steps: Number of scheduler steps, used for the inversion and for
            the optional reconstruction.
        verify: When truthy, swap in ``DDIMScheduler``, regenerate an image
            from the inverted latents and show it beside the input with
            matplotlib (``plt.show()``).

    Returns:
        The inverted latents produced by the pipeline (``output_type='latent'``).
    """
    model_id = 'stabilityai/stable-diffusion-2-1'

    on_cuda = torch.cuda.is_available()
    device = 'cuda' if on_cuda else 'cpu'
    # Half precision is only used on the GPU. On the CPU fallback float16
    # inference is unsupported for some ops or extremely slow, so use float32.
    dtype = torch.float16 if on_cuda else torch.float32

    inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder='scheduler')
    pipe = StableDiffusionPipeline.from_pretrained(model_id,
                                                   scheduler=inverse_scheduler,
                                                   safety_checker=None,
                                                   torch_dtype=dtype)
    pipe.to(device)
    vae = pipe.vae

    input_img = load_image(imgname).to(device=device, dtype=dtype)
    latents = img_to_latents(input_img, vae)

    # Empty prompts with guidance_scale=1. keep the run unconditional;
    # return_dict=False yields a tuple whose first item is the latent output.
    inv_latents, _ = pipe(prompt="", negative_prompt="", guidance_scale=1.,
                          width=input_img.shape[-1], height=input_img.shape[-2],
                          output_type='latent', return_dict=False,
                          num_inference_steps=num_steps, latents=latents)

    # verify: run the forward DDIM process from the inverted latents and
    # compare the reconstruction visually with the input image.
    if verify:
        pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder='scheduler')
        image = pipe(prompt="", negative_prompt="", guidance_scale=1.,
                     num_inference_steps=num_steps, latents=inv_latents)
        _, ax = plt.subplots(1, 2)
        ax[0].imshow(tvt.ToPILImage()(input_img[0]))
        ax[1].imshow(image.images[0])
        plt.show()
    return inv_latents
if __name__ == "__main__":
    # Script entry point: invert './poike.png' using 250 steps and display
    # the reconstruction next to the input image.
    ddim_inversion(imgname='./poike.png', num_steps=250, verify=True)