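"""Train a vanilla GAN (VGAN) with a linear generator and discriminator on MNIST.

The helper modules imported below (utils, data, networks, optimizers, train)
are assumed to live alongside this script in the same repository.
"""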
# ml tools
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
# logging tools
from utils import Logger
# data tools
import data
# network models
from networks import LinearDiscriminator, LinearGenerator
# optimizers
from optimizers import linear_adam_optimizer
# training steps
from train import train_discriminator, train_generator
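# NOTE: signatures assumed from the call sites below:
#   train_discriminator(discriminator, loss, optimizer, real_data, fake_data)
#       returns (error, pred_real, pred_fake)
#   train_generator(discriminator, loss, optimizer, fake_data) returns error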
# load the MNIST dataset and wrap it in a data loader
mnist = data.mnist_data()
data_loader = data.get_data_loader(mnist)
# get the number of batches from the data loader
num_batches = len(data_loader)
# create the discriminator model + optimizer
discriminator = LinearDiscriminator()
d_optimizer = linear_adam_optimizer(discriminator.parameters(), 0.0002)
# create the generator model + optimizer
generator = LinearGenerator()
g_optimizer = linear_adam_optimizer(generator.parameters(), 0.0002)
# move the models to the GPU if available
if torch.cuda.is_available():
    generator.cuda()
    discriminator.cuda()
# create loss function
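# nn.BCELoss expects probabilities in [0, 1], i.e. the discriminator
# is assumed to end in a sigmoid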
loss = nn.BCELoss()
# establish some test noise to feed to the generator throughout training
num_test_samples = 16
test_noise = data.noise(num_test_samples)
# train
def train():
    logger = Logger(model_name="VGAN", data_name="MNIST")
    num_epochs = 200
    for epoch in range(num_epochs):
        for n_batch, (real_batch, _) in enumerate(data_loader):
            N = real_batch.size(0)
            # TRAIN DISCRIMINATOR
            # take a batch of real data from the loader and flatten the images to vectors
            real_data = Variable(data.images_to_vectors(real_batch, 784))
            if torch.cuda.is_available():
                real_data = real_data.cuda()
            # generate fake data, detached so this step does not update the generator
            fake_data = generator(data.noise(N)).detach()
            # train the discriminator
            d_error, d_pred_real, d_pred_fake = train_discriminator(
                discriminator, loss, d_optimizer, real_data, fake_data)
            # TRAIN GENERATOR
            # generate fresh fake data (not detached, so gradients reach the generator)
            fake_data = generator(data.noise(N))
            # train the generator
            g_error = train_generator(discriminator, loss, g_optimizer, fake_data)
            # LOG BATCH ERROR
            logger.log(d_error, g_error, epoch, n_batch, num_batches)
            # display progress every 100 batches
            if n_batch % 100 == 0:
                test_images = data.vectors_to_images(generator(test_noise))
                test_images = test_images.data
                logger.log_images(test_images, num_test_samples, epoch, n_batch, num_batches)
                logger.display_status(epoch, num_epochs, n_batch, num_batches,
                                      d_error, g_error, d_pred_real, d_pred_fake)
if __name__ == "__main__":
    train()