-
Notifications
You must be signed in to change notification settings - Fork 0
/
wgangp_cifar10_image_generation.py
140 lines (111 loc) · 5.07 KB
/
wgangp_cifar10_image_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from wgangp import WGANGP
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
from utils import save_model, load_yaml
# ---- Configuration ----
config = load_yaml("./config/wgangp_cifar10_config.yml")

# ---- Training settings: device selection and reproducibility ----
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(config['data']['seed'])
if device == 'cuda':
    # Seed every visible GPU so CUDA-side sampling is reproducible too
    torch.cuda.manual_seed_all(config['data']['seed'])

# ---- Data pipeline ----
# ToTensor first, then Resize on the tensor (tensor input to Resize
# requires torchvision >= 0.8 — TODO confirm the pinned version)
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize(config['data']['img_size'])])
train_data = datasets.CIFAR10(config['data']['data_path'], download=config['data']['download'],
                              train=True, transform=transform)
# Keep only the horse class (CIFAR-10 label 7)
train_data = torch.utils.data.Subset(train_data, np.where(np.array(train_data.targets) == 7)[0])
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config['data']['batch_size'],
                                           shuffle=config['data']['shuffle'],
                                           drop_last=config['data']['drop_last'])

# ---- Model ----
model = WGANGP(gen_latent_z=config['model']['gen_latent_z'], gen_init_layer=config['model']['gen_init_layer'],
               gen_conv_trans=config['model']['gen_conv_trans'], gen_conv_filters=config['model']['gen_conv_filters'],
               gen_conv_kernels=config['model']['gen_conv_kernels'], gen_conv_strides=config['model']['gen_conv_strides'],
               gen_conv_pads=config['model']['gen_conv_pads'], gen_dropout_rate=config['model']['gen_dropout_rate'],
               crt_input_img=config['model']['crt_input_img'], crt_conv_filters=config['model']['crt_conv_filters'],
               crt_conv_kernels=config['model']['crt_conv_kernels'], crt_conv_strides=config['model']['crt_conv_strides'],
               crt_conv_pads=config['model']['crt_conv_pads'], crt_dropout_rate=config['model']['crt_dropout_rate']).to(device)
print(model, device)

# ---- Optimizers ----
# WGAN-GP trains generator and critic with separate Adam optimizers;
# no classification criterion is needed (the losses are score means),
# so the former unused `criterion = nn.BCELoss()` has been removed.
g_optimizer = optim.Adam(model.G.parameters(),
                         lr=config['train']['lr'],
                         betas=config['train']['betas'])
c_optimizer = optim.Adam(model.C.parameters(),
                         lr=config['train']['lr'],
                         betas=config['train']['betas'])

# ---- Shorthand values read by the training loop ----
batch_size = config['data']['batch_size']
z_latent = config['model']['gen_latent_z']
gen_iteration = config['train']['gen_iteration']
# Training
def train(epoch, train_loader, g_optimizer, c_optimizer):
    """Run one epoch of WGAN-GP training.

    The critic is updated on every batch; the generator only on every
    ``gen_iteration``-th batch (the WGAN n_critic schedule). Reads the
    module-level globals ``model``, ``device``, ``z_latent``,
    ``gen_iteration`` and ``config``.

    Args:
        epoch: Current epoch index (used only in log messages).
        train_loader: DataLoader yielding ``(images, labels)`` batches.
        g_optimizer: Optimizer over the generator's parameters.
        c_optimizer: Optimizer over the critic's parameters.

    Returns:
        Average critic loss per sample over the epoch.
    """
    model.train()
    g_train_loss = 0.0
    g_train_num = 0
    c_train_loss = 0.0
    c_train_num = 0
    for i, data in enumerate(train_loader, 0):
        # ---- Critic step (every batch) ----
        real_img, _ = data  # labels unused: single-class (horse) subset
        real_img = real_img.to(device)
        # Size latent/auxiliary tensors from the actual batch: the final
        # batch can be smaller than config batch_size when drop_last=False,
        # and a hard-coded size would crash or silently mis-broadcast.
        n = real_img.size(0)
        real_score = model.C(real_img)
        # Latent z ~ U[-1, 1); detach the sample so the critic's backward
        # pass does not propagate (wasted) gradients into the generator.
        z = 2 * torch.rand(n, z_latent, device=device) - 1
        fake_img = model.G(z).detach()
        fake_score = model.C(fake_img)
        # Negative EM-distance estimate: critic minimizes fake - real score.
        em_loss = fake_score.mean() - real_score.mean()
        # Gradient penalty on random interpolates between real and fake.
        # WGAN-GP draws the mixing coefficient uniformly from [0, 1);
        # torch.randn (Gaussian) would place interpolates off the segment.
        alpha = torch.rand(n, 1, 1, 1, device=device)
        interpolated_img = (alpha * real_img + (1 - alpha) * fake_img).requires_grad_(True)
        interpolated_score = model.C(interpolated_img)
        # ones_like matches whatever shape the critic emits, instead of
        # assuming a hard-coded (batch, 1) output.
        grad_outputs = torch.ones_like(interpolated_score)
        gradients = torch.autograd.grad(outputs=interpolated_score, inputs=interpolated_img,
                                        grad_outputs=grad_outputs,
                                        create_graph=True, retain_graph=True)[0]
        gradients = gradients.view(n, -1)
        # Penalize deviation of the per-sample gradient norm from 1.
        # NOTE(review): no lambda coefficient is applied (paper uses 10);
        # kept as-is to preserve the configured training dynamics.
        gp_loss = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        c_loss = em_loss + gp_loss
        c_optimizer.zero_grad()
        c_loss.backward()
        c_optimizer.step()
        # ---- Generator step (every gen_iteration batches) ----
        if i % gen_iteration == 0:
            z = 2 * torch.rand(n, z_latent, device=device) - 1
            fake_img = model.G(z)
            fake_score = model.C(fake_img)
            # Generator maximizes the critic's score on its samples.
            g_loss = - fake_score.mean()
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()
            g_train_loss += g_loss.item()
            g_train_num += fake_img.size(0)
        c_train_loss += c_loss.item()
        c_train_num += real_img.size(0)
        if i % config['others']['log_period'] == 0 and i != 0:
            # Running averages (per-batch mean losses summed, divided by
            # sample counts) — scaled figures useful for trend-watching.
            print(f'[{epoch}, {i}]\t Train loss: (G){g_train_loss / g_train_num:.10f}, (D){c_train_loss / c_train_num:.10f}')
    # Average critic loss per sample over the whole epoch
    c_train_loss /= c_train_num
    return c_train_loss
# Entry point: run the full training schedule and checkpoint each epoch.
if __name__ == '__main__':
    n_epochs = config['train']['epochs']
    # Loop over the dataset multiple times
    for epoch in range(n_epochs):
        # One full pass of WGAN-GP training; returns the critic's avg loss
        train_loss = train(epoch, train_loader, g_optimizer, c_optimizer)
        # Epoch-level progress log
        print(f'Epoch: {epoch}\t Train loss: {train_loss:.10f}')
        # Persist model + optimizer state so training can be resumed
        save_model(model_name=config['save']['model_name'],
                   epoch=epoch,
                   model=model,
                   optimizer=g_optimizer,
                   loss=train_loss,
                   config=config)