-
Notifications
You must be signed in to change notification settings - Fork 13
/
vae.py
158 lines (139 loc) · 4.27 KB
/
vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import torch
from diffusers import AutoencoderKL
DTYPE = torch.float16
DEVICE = "cuda:0"
class SDv1_VAE:
    """Wrapper around the SD v1.x VAE ("stabilityai/sd-vae-ft-mse").

    `encode` maps image tensors in [0;1] to sampled latents; `decode`
    maps latents back to images clamped to [0;1]. Latents are 1/8 the
    input's spatial resolution with 4 channels.
    """
    scale = 1/8    # latent spatial size relative to the image
    channels = 4   # latent channel count

    def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
        """Load the pretrained VAE.

        device: torch device the model runs on.
        dtype: torch dtype for weights/activations.
        dec_only: delete the encoder to save memory; encode() is then unusable.
        """
        self.device = device
        self.dtype = dtype
        self.model = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse"
        )
        self.model.eval().to(self.dtype).to(self.device)
        if dec_only:
            del self.model.encoder

    def encode(self, image):
        """Encode an image in [0;1] to a sampled latent.

        Returns the latent converted back to the caller's original
        dtype/device.
        """
        src_dtype, src_device = image.dtype, image.device
        image = image.to(self.dtype).to(self.device)
        image = (image * 2.0) - 1.0  # assuming input is [0;1]
        with torch.no_grad():
            latent = self.model.encode(image).latent_dist.sample()
        # BUGFIX: previously converted to `image.dtype`/`image.device`
        # AFTER `image` had been moved to the model's dtype/device, which
        # made the restoration a no-op. Restore the caller's originals.
        return latent.to(src_dtype).to(src_device)

    def decode(self, latent, grad=False):
        """Decode a latent to an image clamped to [0;1].

        grad: when True, keep the autograd graph through the decoder.
        Returns the image converted back to the caller's original
        dtype/device.
        """
        src_dtype, src_device = latent.dtype, latent.device
        latent = latent.to(self.dtype).to(self.device)
        if grad:
            out = self.model.decode(latent)[0]
        else:
            with torch.no_grad():
                out = self.model.decode(latent).sample
        out = torch.clamp(out, min=-1.0, max=1.0)
        out = (out + 1.0) / 2.0
        # Same no-op restoration bug fixed here as in encode().
        return out.to(src_dtype).to(src_device)
class SDXL_VAE(SDv1_VAE):
    """SDXL VAE (fp16-safe weights); reuses SDv1_VAE encode/decode."""
    scale = 1/8
    channels = 4

    def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
        self.device = device
        self.dtype = dtype
        model = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
        model.eval().to(self.dtype).to(self.device)
        if dec_only:
            del model.encoder
        self.model = model
class SDv3_VAE(SDv1_VAE):
    """SD3 (medium) VAE; 16 latent channels, SDv1_VAE encode/decode."""
    scale = 1/8
    channels = 16

    def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
        self.device = device
        self.dtype = dtype
        model = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers",
            subfolder="vae",
        )
        model.eval().to(self.dtype).to(self.device)
        if dec_only:
            del model.encoder
        self.model = model
class FLUX_VAE(SDv1_VAE):
    """FLUX.1-dev VAE; 16 latent channels, SDv1_VAE encode/decode."""
    scale = 1/8
    channels = 16

    def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
        self.device = device
        if dec_only:
            # decoder NaNs randomly (in the requested dtype), force bf16
            self.dtype = torch.bfloat16
        else:
            self.dtype = dtype
        model = AutoencoderKL.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            subfolder="vae",
        )
        model.eval().to(self.dtype).to(self.device)
        if dec_only:
            del model.encoder
        self.model = model
class CascadeC_VAE(SDv1_VAE):
    """Stable Cascade stage-C EfficientNet encoder.

    NOTE(review): `self.model` is an `EfficientNetEncoder`, not an
    `AutoencoderKL`; the encode/decode inherited from SDv1_VAE
    presumably do not match its API — verify before calling them.
    """
    scale = 1/32
    channels = 16

    def __init__(self, device=DEVICE, dtype=DTYPE, **kwargs):
        self.device = device
        self.dtype = dtype
        # For now this is just piggybacking off of koyha-ss/sd-scripts
        from huggingface_hub import hf_hub_download
        from safetensors.torch import load_file
        from library import stable_cascade as sc

        ckpt_path = hf_hub_download(
            repo_id="stabilityai/stable-cascade",
            filename="effnet_encoder.safetensors",
        )
        self.model = sc.EfficientNetEncoder()
        self.model.load_state_dict(load_file(str(ckpt_path)))
        self.model.eval().to(self.dtype).to(self.device)
class CascadeA_VAE():
    """Stable Cascade stage-A VQGAN (PaellaVQModel) wrapper.

    Unlike the KL VAEs, encode() takes images in [0;1] directly (no
    rescale to [-1;1]) and decode() clamps straight to [0;1].
    """
    scale = 1/4    # latent spatial size relative to the image
    channels = 4   # latent channel count

    def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
        """Load the VQGAN.

        dec_only: delete the encoder to save memory; encode() is then unusable.
        """
        self.device = device
        self.dtype = dtype
        # not sure if this will change in the future?
        from diffusers.pipelines.wuerstchen.modeling_paella_vq_model import PaellaVQModel
        self.model = PaellaVQModel.from_pretrained(
            "stabilityai/stable-cascade",
            subfolder="vqgan"
        )
        self.model.eval().to(self.dtype).to(self.device)
        if dec_only:
            del self.model.encoder

    def encode(self, image):
        """Encode an image to VQ latents, returned on the caller's
        original dtype/device."""
        src_dtype, src_device = image.dtype, image.device
        image = image.to(self.dtype).to(self.device)
        with torch.no_grad():
            latent = self.model.encode(image).latents
        # BUGFIX: previously used `image.dtype`/`image.device` after
        # `image` had been rebound to the model's dtype/device — a no-op.
        return latent.to(src_dtype).to(src_device)

    def decode(self, latent, grad=False):
        """Decode latents to an image clamped to [0;1], returned on the
        caller's original dtype/device.

        grad: when True, keep the autograd graph through the decoder.
        """
        src_dtype, src_device = latent.dtype, latent.device
        latent = latent.to(self.dtype).to(self.device)
        if grad:
            out = self.model.decode(latent)[0]
        else:
            with torch.no_grad():
                out = self.model.decode(latent).sample
        out = torch.clamp(out, min=0.0, max=1.0)
        # Same no-op restoration bug fixed here as in encode().
        return out.to(src_dtype).to(src_device)
class No_VAE():
    """Identity 'VAE' that passes images through unchanged.

    Lets pixel-space pipelines use the same interface as the real VAE
    wrappers.
    """
    scale = 1      # no spatial scaling
    channels = 3   # raw RGB passthrough

    def __init__(self, *args, **kwargs):
        # Accept and ignore the (device, dtype, dec_only) arguments the
        # other wrappers take, so load_vae() can construct it uniformly.
        pass

    def encode(self, image):
        """Return the input unchanged."""
        return image

    def decode(self, image, grad=False):
        """Return the input unchanged.

        grad: accepted (and ignored) for interface parity with the other
        wrappers' decode(latent, grad=...) — previously missing, so
        generic callers passing grad= would crash here.
        """
        return image
# Registry mapping short version tags to the VAE wrapper classes above;
# used by load_vae() below.
vae_vers = {
    "no": No_VAE,        # identity passthrough (pixel space)
    "v1": SDv1_VAE,      # SD v1.x (sd-vae-ft-mse)
    "xl": SDXL_VAE,      # SDXL (fp16-fix weights)
    "v3": SDv3_VAE,      # SD3 medium
    "cc": CascadeC_VAE,  # Stable Cascade stage C (EffNet encoder)
    "ca": CascadeA_VAE,  # Stable Cascade stage A (VQGAN)
    "fx": FLUX_VAE,      # FLUX.1-dev
}
def load_vae(ver, *args, **kwargs):
    """Instantiate the VAE wrapper registered under tag `ver`.

    ver: one of the keys of `vae_vers` ("no", "v1", "xl", "v3", "cc",
        "ca", "fx").
    Remaining args/kwargs are forwarded to the wrapper's constructor.
    Raises ValueError for an unknown tag — an explicit raise instead of
    the previous `assert`, which is stripped under `python -O`.
    """
    if ver not in vae_vers:
        raise ValueError(f"Unknown VAE '{ver}'")
    return vae_vers[ver](*args, **kwargs)