train_vqvae.py
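"""Train a VQ-VAE on chest X-ray images stored in HDF5 files.

Overview (inferred from the code below; `VQVAE`, `ChestXrayHDF5`, `recon_image`, and
`save_loss_plots` come from the local `networks` / `utilities` modules and are assumed
to behave as they are used here): the model returns a reconstruction and a latent
(commitment) loss, the dataset yields (image, label) pairs, reconstruction snapshots
and loss plots are written under `--save_path`, and checkpoints are saved every 10 epochs.
"""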
import argparse
import os
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
from tqdm import tqdm
from networks import VQVAE
from utilities import ChestXrayHDF5, recon_image, save_loss_plots
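# Use CUDA tensors when a GPU is available; `Tensor` is the FloatTensor type that batches are cast to below.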
cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
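# Command-line configuration; the two stride arguments are forwarded to VQVAE and, per their
# help strings, are expected to be 2, 4, 8, or 16.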
parser = argparse.ArgumentParser()
parser.add_argument('--size', type=int, default=256)
parser.add_argument('--n_epochs', type=int, default=5600)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--first_stride', type=int, default=4, help="2, 4, 8, or 16")
parser.add_argument('--second_stride', type=int, default=2, help="2, 4, 8, or 16")
parser.add_argument('--embed_dim', type=int, default=64)
parser.add_argument('--data_path', type=str, default='/home/aisinai/work/HDF5_datasets')
parser.add_argument('--dataset', type=str, default='CheXpert', help="CheXpert or mimic")
parser.add_argument('--view', type=str, default='frontal', help="frontal or lateral")
parser.add_argument('--save_path', type=str, default='/home/aisinai/work/VQ-VAE2/20200820/vq_vae')
parser.add_argument('--train_run', type=str, default='0')
args = parser.parse_args()
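# Fix the random seed, create the output tree (checkpoints and reconstruction samples),
# and dump the run configuration to args.txt.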
torch.manual_seed(816)
save_path = f'{args.save_path}/{args.dataset}/{args.train_run}'
os.makedirs(save_path, exist_ok=True)
os.makedirs(f'{save_path}/checkpoint/', exist_ok=True)
os.makedirs(f'{save_path}/sample/', exist_ok=True)
with open(f'{save_path}/args.txt', 'w') as f:
    for key in vars(args).keys():
        f.write(f'{key}: {vars(args)[key]}\n')
        print(f'{key}: {vars(args)[key]}')
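# HDF5-backed train/valid loaders; file names follow <dataset>_<split>_<size>_<view>_normalized.hdf5.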
dataloaders = {}
dataloaders['train'] = DataLoader(ChestXrayHDF5(f'{args.data_path}/{args.dataset}_train_{args.size}_{args.view}_normalized.hdf5'),
                                  batch_size=128,
                                  shuffle=True,
                                  drop_last=True)
dataloaders['valid'] = DataLoader(ChestXrayHDF5(f'{args.data_path}/{args.dataset}_valid_{args.size}_{args.view}_normalized.hdf5'),
                                  batch_size=128,
                                  shuffle=True,
                                  drop_last=True)
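# Keep one fixed validation batch for periodic reconstruction snapshots via recon_image.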
for i, (img, targets) in enumerate(dataloaders['valid']):
    sample_img = Variable(img.type(Tensor))
    break
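# Build the VQ-VAE (on GPU when available) and wrap it in DataParallel when several GPUs are visible.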
if cuda:
    model = VQVAE(first_stride=args.first_stride, second_stride=args.second_stride, embed_dim=args.embed_dim).cuda()
else:
    model = VQVAE(first_stride=args.first_stride, second_stride=args.second_stride, embed_dim=args.embed_dim)

n_gpu = torch.cuda.device_count()
if n_gpu > 1:
    device_ids = list(range(n_gpu))
    model = nn.DataParallel(model, device_ids=device_ids)
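# A single Adam optimizer covers all model parameters; `losses` ends up holding, per epoch,
# the loss values of the final batch seen in that epoch (row 0 = train, row 1 = valid).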
optimizer = optim.Adam(model.parameters(), lr=args.lr)
losses = np.zeros((2, args.n_epochs, 3)) # [0,:,:] index for train, [1,:,:] index for valid
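# Each epoch runs a training pass followed by a validation pass through the same code path;
# gradients are enabled only for training via torch.set_grad_enabled. The objective is the
# standard VQ-VAE loss: MSE reconstruction error plus 0.25 * the latent (commitment) loss.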
for epoch in range(args.n_epochs):
    for phase in ['train', 'valid']:
        model.train(phase == 'train')  # True when 'train', False when 'valid'
        criterion = nn.MSELoss()
        latent_loss_weight = 0.25
        n_row = 5
        loader = tqdm(dataloaders[phase])
        for i, (img, label) in enumerate(loader):
            img = Variable(img.type(Tensor))
            with torch.set_grad_enabled(phase == 'train'):
                optimizer.zero_grad()
                out, latent_loss = model(img)
                recon_loss = criterion(out, img)
                latent_loss = latent_loss.mean()
                loss = recon_loss + latent_loss_weight * latent_loss
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                    losses[0, epoch, :] = [loss.item(), recon_loss.item(), latent_loss.item()]
                else:
                    losses[1, epoch, :] = [loss.item(), recon_loss.item(), latent_loss.item()]
                lr = optimizer.param_groups[0]['lr']
                loader.set_description((f'phase: {phase}; epoch: {epoch + 1}; total_loss: {loss.item():.5f}; '
                                        f'latent: {latent_loss.item():.5f}; mse: {recon_loss.item():.5f}; '
                                        f'lr: {lr:.5f}'))
                if i % 10 == 0:
                    recon_image(n_row, sample_img, model, f'{save_path}', epoch, Tensor)
    save_loss_plots(args.n_epochs, epoch, losses, f'{save_path}')
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), f'{save_path}/checkpoint/vqvae_{str(epoch + 1).zfill(3)}.pt')
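
# Example invocation (a sketch; the dataset and output paths are placeholders for your environment):
#   python train_vqvae.py --dataset CheXpert --view frontal --size 256 \
#       --first_stride 4 --second_stride 2 --embed_dim 64 --lr 3e-4 \
#       --data_path /path/to/HDF5_datasets --save_path /path/to/output --train_run 0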