# train_vqgan.py: script for training the VQGAN
import torch
import numpy as np
import copy
import time
import random
from torchvision.transforms.functional import hflip
from models.vqgan import VQGAN
from hparams import get_vqgan_hparams
from utils.data_utils import get_data_loaders, cycle
from utils.train_utils import EMA
from utils.log_utils import log, log_stats, save_model, save_stats, save_images, \
    display_images, set_up_visdom, config_log, start_training_log
from utils.vqgan_utils import load_vqgan_from_checkpoint, calc_FID
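
# cudnn benchmark mode autotunes convolution algorithms, which helps when input
# shapes stay fixed throughout training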
torch.backends.cudnn.benchmark = True


def main(H, vis):
    vqgan = VQGAN(H).cuda()

    # only load val_loader if running eval
    train_loader, val_loader = get_data_loaders(
        H.dataset,
        H.img_size,
        H.batch_size,
        get_val_dataloader=(H.steps_per_eval != 0)
    )
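
    # cycle() wraps each DataLoader in an infinite iterator so the loop below
    # can be purely step-based rather than epoch-based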
    train_iterator = cycle(train_loader)
    if val_loader is not None:
        val_iterator = cycle(val_loader)
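
    # EMA maintains an exponentially weighted moving average of the model
    # weights, which typically evaluates more stably than the online weights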
    if H.ema:
        ema = EMA(H.ema_beta)
        ema_vqgan = copy.deepcopy(vqgan)
    else:
        ema_vqgan = None

    optim = torch.optim.Adam(vqgan.ae.parameters(), lr=H.lr)
    d_optim = torch.optim.Adam(vqgan.disc.parameters(), lr=H.lr)
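
    # separate GradScalers for the autoencoder and the discriminator, since the
    # two losses are scaled and stepped independently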
    if H.amp:
        scaler = torch.cuda.amp.GradScaler()
        d_scaler = torch.cuda.amp.GradScaler()

    losses = np.array([])
    mean_losses = np.array([])
    val_losses = np.array([])
    recon_losses = np.array([])
    latent_ids = []
    fids = np.array([])
    best_fid = float('inf')

    # NOTE: this step bookkeeping is getting messy; maintaining an explicit
    # list of steps would be simpler
    start_step = 0
    log_start_step = 0
    eval_start_step = H.steps_per_eval
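
    # optionally resume model, optimiser and stats state from a checkpoint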
    if H.load_step > 0:
        start_step = H.load_step + 1  # don't repeat the checkpointed step
        vqgan, optim, d_optim, ema_vqgan, train_stats = load_vqgan_from_checkpoint(
            H, vqgan, optim, d_optim, ema_vqgan
        )

        # stats won't load for old models with no associated stats file
        if train_stats is not None:
            losses = train_stats["losses"]
            mean_losses = train_stats["mean_losses"]
            val_losses = train_stats["val_losses"]
            latent_ids = train_stats["latent_ids"]
            fids = train_stats["fids"]
            best_fid = train_stats["best_fid"]
            H.steps_per_log = train_stats["steps_per_log"]
            H.steps_per_eval = train_stats["steps_per_eval"]
            log_start_step = 0
            eval_start_step = H.steps_per_eval
            log('Loaded stats')
        else:
            log_start_step = start_step
            if H.steps_per_eval:
                if H.steps_per_eval == 1:
                    eval_start_step = start_step
                else:
                    # advance eval_start_step to the next multiple of steps_per_eval
                    # NOTE: an explicit list of eval steps would avoid this arithmetic
                    eval_start_step = start_step + H.steps_per_eval - start_step % H.steps_per_eval

    steps_per_epoch = len(train_loader)
    log(f'Epoch length: {steps_per_epoch}')
    log(f"ae params: {sum(p.numel() for p in vqgan.ae.parameters())}")
    log(f"disc params: {sum(p.numel() for p in vqgan.disc.parameters())}")
    log(f"total params: {sum(p.numel() for p in vqgan.ae.parameters()) + sum(p.numel() for p in vqgan.disc.parameters())}")
    for step in range(start_step, H.train_steps):
        step_start_time = time.time()

        batch = next(train_iterator)
        if isinstance(batch, list):
            x = batch[0]
        else:
            x = batch

        if H.horizontal_flip:
            if random.random() <= 0.5:
                x = hflip(x)
        x = x.cuda()

        if H.amp:
            optim.zero_grad()
            with torch.cuda.amp.autocast():
                x_hat, stats = vqgan.train_iter(x, step)
            scaler.scale(stats['loss']).backward()
            scaler.step(optim)
            scaler.update()
        else:
            x_hat, stats = vqgan.train_iter(x, step)
            optim.zero_grad()
            stats['loss'].backward()
            optim.step()

        losses = np.append(losses, stats['loss'].item())
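
        # the discriminator only trains after disc_start_step; delaying the
        # adversarial loss until the autoencoder has warmed up follows standard
        # VQGAN training practice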
        if step > H.disc_start_step:
            if H.amp:
                d_optim.zero_grad()
                d_scaler.scale(stats['d_loss']).backward()
                d_scaler.step(d_optim)
                d_scaler.update()
            else:
                d_optim.zero_grad()
                stats['d_loss'].backward()
                d_optim.step()

        # collect latent ids
        latent_ids.append(stats['latent_ids'].cpu().contiguous())

        if step % H.steps_per_log == 0:
            mean_loss = np.mean(losses)
            stats['loss'] = mean_loss
            step_time = time.time() - step_start_time
            stats['step_time'] = step_time
            mean_losses = np.append(mean_losses, mean_loss)
            recon_losses = np.append(recon_losses, stats['l1'])
            losses = np.array([])

            vis.line(
                mean_losses,
                list(range(log_start_step, step + 1, H.steps_per_log)),
                win='loss',
                opts=dict(title='Loss')
            )
            # vis.line(
            #     recon_losses,
            #     list(range(log_start_step, step + 1, H.steps_per_log)),
            #     win='recon_loss',
            #     opts=dict(title='Train L1 Loss')
            # )
            log_stats(step, stats)

        # bundled validation loss and FID calculation
        # NOTE: could be pulled out into a separate function
        if H.steps_per_eval:
            if step % H.steps_per_eval == 0 and step > 0:
                # log('Evaluating FIDs and validation loss:')
                # vqgan.eval()
                # # Calc FIDs
                # fid = calc_FID(H, ema_vqgan if H.ema else vqgan)
                # fids = np.append(fids, fid)
                # log(f'FID: {fid}')
                # if fid < best_fid:
                #     save_model(ema_vqgan if H.ema else vqgan, 'vqgan_bestfid', step, H.log_dir)

                # Calc validation loss on a held-out batch
                x_val = next(val_iterator)
                if H.deepspeed:
                    x_val = x_val.half()
                _, val_stats = vqgan.val_iter(x_val, step)
                val_losses = np.append(val_losses, val_stats['l1'])
                steps = list(range(eval_start_step, step + 1, H.steps_per_eval))
                # vis.line(fids, steps, win='FID', opts=dict(title='FID'))
                vis.line(val_losses, steps, win='val', opts=dict(title='Validation L1 Loss'))
                # vqgan.train()

        # log codebook usage at the end of each epoch
        if step % steps_per_epoch == 0 and step > 0:
            latent_ids = torch.cat(latent_ids, dim=0)
            unique_code_ids = torch.unique(latent_ids).to(dtype=torch.int64)
            log(f'Codebook size: {H.codebook_size}, unique codes used this epoch: {len(unique_code_ids)}')
            latent_ids = []

        if H.ema and step % H.steps_per_update_ema == 0 and step > 0:
            ema.update_model_average(ema_vqgan, vqgan)

        if step % H.steps_per_display_output == 0 and step > 0:
            display_images(vis, x, H, 'Original Images')
            # if H.ema:
            #     x_hat, _ = ema_vqgan.train_iter(x, step)
            x_hat = x_hat.detach().cpu().to(torch.float32)
            display_images(vis, x_hat, H, 'VQGAN Recons')

        if step % H.steps_per_save_output == 0:
            save_images(x_hat, 'recons', step, H.log_dir, H.save_individually)
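
        # periodically checkpoint the model, both optimisers and the EMA copy,
        # along with running stats so training can be resumed via H.load_step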
        if step % H.steps_per_checkpoint == 0 and step > H.load_step:
            save_model(vqgan, 'vqgan', step, H.log_dir)
            save_model(optim, 'ae_optim', step, H.log_dir)
            save_model(d_optim, 'disc_optim', step, H.log_dir)
            if H.ema:
                save_model(ema_vqgan, 'vqgan_ema', step, H.log_dir)

            train_stats = {
                'losses': losses,
                'mean_losses': mean_losses,
                'val_losses': val_losses,
                'latent_ids': latent_ids,
                'fids': fids,
                'best_fid': best_fid,
                'steps_per_log': H.steps_per_log,
                'steps_per_eval': H.steps_per_eval,
            }
            save_stats(H, train_stats, step)


if __name__ == '__main__':
    H = get_vqgan_hparams()
    vis = set_up_visdom(H)
    config_log(H.log_dir)
    log('---------------------------------')
    log(f'Setting up training for VQGAN on {H.dataset}')
    start_training_log(H)
    main(H, vis)