-
Notifications
You must be signed in to change notification settings - Fork 85
/
run.py
96 lines (80 loc) · 3.89 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
Trains MADE on Binarized MNIST, which can be downloaded here:
https://github.com/mgermain/MADE/releases/download/ICML2015/binarized_mnist.npz
"""
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from made import MADE
# ------------------------------------------------------------------------------
def run_epoch(split, upto=None):
    """Run one pass over the data, print the mean batch loss and return it.

    Relies on module-level state created by the script entry point:
    ``model``, ``opt``, ``args``, ``xtr`` and ``xte``.

    Args:
        split: ``'train'`` iterates over ``xtr``, puts the model in training
            mode and applies an optimizer step per batch. Any other value
            iterates over ``xte`` in eval mode with gradients disabled; the
            value ``'test'`` additionally resamples the masks on every
            forward pass.
        upto: Optional cap on the number of batches to process. ``None``
            processes every full batch.

    Returns:
        The mean per-batch loss as a float (``nan`` when no batch was run).
        Each batch loss is the binary cross entropy summed over all elements
        of the batch and divided by the batch size.
    """
    is_train = split == 'train'
    if is_train:
        model.train()
    else:
        model.eval()
    nsamples = 1 if is_train else args.samples
    x = xtr if is_train else xte
    N = x.size(0)
    B = 100  # batch size
    # Only full batches are used; a trailing remainder of N % B rows is skipped.
    nsteps = N // B if upto is None else min(N // B, upto)
    lossfs = []
    # Used as a context manager so the caller's grad mode is restored on exit
    # instead of leaving autograd globally disabled after an evaluation pass.
    with torch.set_grad_enabled(is_train):
        for step in range(nsteps):
            # fetch the next batch of data
            xb = x[step * B:(step + 1) * B]
            # accumulate logits, potentially running the same batch several
            # times with freshly sampled masks, then average them
            xbhat = torch.zeros_like(xb)
            for _ in range(nsamples):
                # order/connectivity-agnostic training: resample the masks
                # periodically while training, and on every pass in test
                if step % args.resample_every == 0 or split == 'test':
                    model.update_masks()
                xbhat += model(xb)
            xbhat /= nsamples
            # reduction='sum' is the non-deprecated spelling of
            # size_average=False
            loss = F.binary_cross_entropy_with_logits(xbhat, xb, reduction='sum') / B
            lossfs.append(loss.item())
            # backward/update
            if is_train:
                opt.zero_grad()
                loss.backward()
                opt.step()
    # Avoid numpy's "mean of empty slice" warning when no batch was processed.
    mean_loss = float(np.mean(lossfs)) if lossfs else float('nan')
    print("%s epoch average loss: %f" % (split, mean_loss))
    return mean_loss
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Script entry point: parse options, load the data, build the model and
    # optimizer, then train for a fixed number of epochs. The names `args`,
    # `xtr`, `xte`, `model` and `opt` are module-level on purpose because
    # run_epoch() reads them as globals.
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--data-path', required=True, type=str, help="Path to binarized_mnist.npz")
    parser.add_argument('-q', '--hiddens', type=str, default='500', help="Comma separated sizes for hidden layers, e.g. 500, or 500,500")
    parser.add_argument('-n', '--num-masks', type=int, default=1, help="Number of orderings for order/connection-agnostic training")
    parser.add_argument('-r', '--resample-every', type=int, default=20, help="For efficiency we can choose to resample orders/masks only once every this many steps")
    parser.add_argument('-s', '--samples', type=int, default=1, help="How many samples of connectivity/masks to average logits over during inference")
    args = parser.parse_args()
    # run_epoch() computes `step % resample_every` and divides the summed
    # logits by `samples`, so both must be strictly positive.
    if args.resample_every < 1:
        parser.error("--resample-every must be a positive integer")
    if args.samples < 1:
        parser.error("--samples must be a positive integer")
    # --------------------------------------------------------------------------

    # reproducibility is good
    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    # use the GPU when one is available, otherwise fall back to the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # load the dataset; the context manager closes the .npz archive once the
    # arrays have been read into memory
    print("loading binarized mnist from", args.data_path)
    with np.load(args.data_path) as mnist:
        # NOTE: the held-out split used here is 'valid_data', even though the
        # log messages below call it "test".
        xtr, xte = mnist['train_data'], mnist['valid_data']
    xtr = torch.from_numpy(xtr).to(device)
    xte = torch.from_numpy(xte).to(device)

    # construct model and ship it to the selected device
    hidden_list = list(map(int, args.hiddens.split(',')))
    model = MADE(xtr.size(1), hidden_list, xtr.size(1), num_masks=args.num_masks)
    print("number of model parameters:", sum(p.numel() for p in model.parameters()))
    model.to(device)

    # set up the optimizer
    opt = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=45, gamma=0.1)

    # start the training
    for epoch in range(100):
        print("epoch %d" % (epoch, ))
        run_epoch('test', upto=5) # run only a few batches for approximate test accuracy
        run_epoch('train')
        # Step the schedule once per epoch, after the optimizer updates. This
        # yields the same learning rate per epoch as the former
        # scheduler.step(epoch) at the top of the loop, without the
        # deprecated `epoch` argument.
        scheduler.step()

    print("optimization done. full test set eval:")
    run_epoch('test')