-
Notifications
You must be signed in to change notification settings - Fork 322
/
Copy pathbasic_vae_module.py
221 lines (173 loc) · 7 KB
/
basic_vae_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import urllib.parse
from argparse import ArgumentParser
import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from torch import nn
from torch.nn import functional as F
from pl_bolts import _HTTPS_AWS_HUB
from pl_bolts.models.autoencoders.components import (
resnet18_decoder,
resnet18_encoder,
resnet50_decoder,
resnet50_encoder,
)
class VAE(LightningModule):
    """Standard VAE with Gaussian Prior and approx posterior.

    Model is available pretrained on different datasets:

    Example::

        # not pretrained (``input_height`` is a required argument)
        vae = VAE(input_height=32)

        # pretrained on cifar10
        vae = VAE(input_height=32).from_pretrained('cifar10-resnet18')

        # pretrained on stl10
        vae = VAE(input_height=32).from_pretrained('stl10-resnet18')
    """

    # Maps a pretrained-checkpoint name to the URL it is loaded from.
    pretrained_urls = {
        "cifar10-resnet18": urllib.parse.urljoin(_HTTPS_AWS_HUB, "vae/vae-cifar10/checkpoints/epoch%3D89.ckpt"),
        "stl10-resnet18": urllib.parse.urljoin(_HTTPS_AWS_HUB, "vae/vae-stl10/checkpoints/epoch%3D89.ckpt"),
    }

    def __init__(
        self,
        input_height: int,
        enc_type: str = "resnet18",
        first_conv: bool = False,
        maxpool1: bool = False,
        enc_out_dim: int = 512,
        kl_coeff: float = 0.1,
        latent_dim: int = 256,
        lr: float = 1e-4,
        **kwargs,
    ):
        """
        Args:
            input_height: height of the images
            enc_type: option between resnet18 or resnet50; any other value
                falls back to resnet18 (a warning is emitted)
            first_conv: use standard kernel_size 7, stride 2 at start or
                replace it with kernel_size 3, stride 1 conv
            maxpool1: use standard maxpool to reduce spatial dim of feat by a factor of 2
            enc_out_dim: set according to the out_channel count of
                encoder used (512 for resnet18, 2048 for resnet50)
            kl_coeff: coefficient for kl term of the loss
            latent_dim: dim of latent space
            lr: learning rate for Adam
            **kwargs: accepted and ignored by this constructor, so that a full
                argparse namespace can be splatted in (see ``cli_main``)
        """
        super().__init__()

        # Called before any other work so it sees the constructor arguments as passed.
        self.save_hyperparameters()

        self.lr = lr
        self.kl_coeff = kl_coeff
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        valid_encoders = {
            "resnet18": {
                "enc": resnet18_encoder,
                "dec": resnet18_decoder,
            },
            "resnet50": {
                "enc": resnet50_encoder,
                "dec": resnet50_decoder,
            },
        }

        resolved_enc_type = enc_type
        if resolved_enc_type not in valid_encoders:
            # Keep the historical fallback, but make it visible: the saved
            # hyperparameters still carry the unrecognised value.
            import warnings

            warnings.warn(
                f"Unknown enc_type {enc_type!r}; falling back to 'resnet18'. "
                f"Valid options are {sorted(valid_encoders)}.",
                stacklevel=2,
            )
            resolved_enc_type = "resnet18"

        builders = valid_encoders[resolved_enc_type]
        self.encoder = builders["enc"](first_conv, maxpool1)
        self.decoder = builders["dec"](self.latent_dim, self.input_height, first_conv, maxpool1)

        # Two linear heads on the encoder output: posterior mean and log-variance.
        self.fc_mu = nn.Linear(self.enc_out_dim, self.latent_dim)
        self.fc_var = nn.Linear(self.enc_out_dim, self.latent_dim)

    @staticmethod
    def pretrained_weights_available():
        """Return the list of checkpoint names accepted by ``from_pretrained``."""
        return list(VAE.pretrained_urls)

    def from_pretrained(self, checkpoint_name):
        """Load the pretrained checkpoint registered under ``checkpoint_name``.

        Args:
            checkpoint_name: a key of ``VAE.pretrained_urls``

        Returns:
            Whatever ``load_from_checkpoint`` returns for the checkpoint URL
            (called with ``strict=False``).

        Raises:
            KeyError: if ``checkpoint_name`` is not a known pretrained checkpoint.
        """
        if checkpoint_name not in VAE.pretrained_urls:
            raise KeyError(f"{checkpoint_name} not present in pretrained weights.")

        return self.load_from_checkpoint(VAE.pretrained_urls[checkpoint_name], strict=False)

    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample a latent from the posterior and return its decoding."""
        # Same pipeline as ``_run_step``; only the reconstruction is needed here.
        _, x_hat, _, _ = self._run_step(x)
        return x_hat

    def _run_step(self, x: torch.Tensor):
        """Run encoder -> (mu, log_var) -> sample -> decoder.

        Returns:
            Tuple ``(z, x_hat, p, q)``: the sampled latent, its decoding, the
            prior distribution and the approximate posterior distribution.
        """
        x = self.encoder(x)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        p, q, z = self.sample(mu, log_var)
        return z, self.decoder(z), p, q

    def sample(self, mu: torch.Tensor, log_var: torch.Tensor):
        """Build the prior/posterior and draw a reparameterized latent sample.

        Args:
            mu: posterior mean
            log_var: posterior log-variance

        Returns:
            Tuple ``(p, q, z)`` where ``p`` is ``Normal(0, 1)`` shaped like
            ``mu``, ``q`` is ``Normal(mu, std)`` and ``z`` is ``q.rsample()``.
        """
        # log_var = log(std ** 2)  =>  std = exp(log_var / 2)
        std = torch.exp(log_var / 2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        return p, q, z

    def step(self, batch, batch_idx):
        """Compute the loss for one batch.

        Args:
            batch: a 2-element ``(input, target)`` pair; the target is ignored
            batch_idx: index of the batch (unused)

        Returns:
            Tuple ``(loss, logs)`` where ``loss = kl + recon_loss`` and ``logs``
            holds ``recon_loss``, ``kl`` (already scaled by ``kl_coeff``) and ``loss``.
        """
        x, _ = batch
        _, x_hat, p, q = self._run_step(x)

        recon_loss = F.mse_loss(x_hat, x, reduction="mean")

        # Mean KL(q || p) over all elements, weighted by the KL coefficient.
        kl = torch.distributions.kl_divergence(q, p).mean() * self.kl_coeff

        loss = kl + recon_loss

        logs = {
            "recon_loss": recon_loss,
            "kl": kl,
            "loss": loss,
        }
        return loss, logs

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log each metric per step with a ``train_`` prefix."""
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log each metric with a ``val_`` prefix."""
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a new parser extending ``parent_parser`` with the model/data options.

        Args:
            parent_parser: parser whose arguments are inherited

        Returns:
            An ``ArgumentParser`` (built with ``add_help=False``) carrying the
            parent's arguments plus the options added below.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        parser.add_argument("--enc_type", type=str, default="resnet18", help="resnet18/resnet50")
        parser.add_argument("--first_conv", action="store_true")
        parser.add_argument("--maxpool1", action="store_true")
        parser.add_argument("--lr", type=float, default=1e-4)

        parser.add_argument(
            "--enc_out_dim",
            type=int,
            default=512,
            help="512 for resnet18, 2048 for bigger resnets, adjust for wider resnets",
        )
        parser.add_argument("--kl_coeff", type=float, default=0.1)
        parser.add_argument("--latent_dim", type=int, default=256)

        parser.add_argument("--batch_size", type=int, default=256)
        parser.add_argument("--num_workers", type=int, default=8)
        parser.add_argument("--data_dir", type=str, default=".")

        return parser
def cli_main(args=None):
    """Command-line entry point: pick a datamodule, build a VAE and fit it.

    Args:
        args: optional list of argument strings; ``None`` lets argparse fall
            back to its default (the process command line).

    Returns:
        Tuple ``(dm, model, trainer)``: the datamodule, the ``VAE`` instance and
        the ``Trainer`` after ``fit`` has been called.

    Raises:
        ValueError: if the selected dataset has no associated datamodule.
    """
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule

    seed_everything()

    # Dataset name -> datamodule class; the keys double as the CLI choices.
    datamodule_by_name = {
        "cifar10": CIFAR10DataModule,
        "stl10": STL10DataModule,
        "imagenet": ImagenetDataModule,
    }

    parser = ArgumentParser()
    parser.add_argument("--dataset", default="cifar10", type=str, choices=list(datamodule_by_name))

    # First pass: only --dataset is known here, everything else is tolerated.
    script_args, _ = parser.parse_known_args(args)

    dm_cls = datamodule_by_name.get(script_args.dataset)
    if dm_cls is None:
        raise ValueError(f"undefined dataset {script_args.dataset}")

    # Second pass: full parse with the model and trainer options added.
    parser = Trainer.add_argparse_args(VAE.add_model_specific_args(parser))
    parsed = parser.parse_args(args)

    dm = dm_cls.from_argparse_args(parsed)
    # The model's input height is taken from the last entry of the datamodule's size.
    parsed.input_height = dm.size()[-1]

    # A max_steps of -1 is passed on to the model and trainer as None.
    if parsed.max_steps == -1:
        parsed.max_steps = None

    model = VAE(**vars(parsed))

    trainer = Trainer.from_argparse_args(parsed)
    trainer.fit(model, datamodule=dm)

    return dm, model, trainer
# Script entry point: run the CLI and keep its results as module-level names.
if __name__ == "__main__":
    dm, model, trainer = cli_main()