gen_gmm.py (123 lines, 83 loc, 3.56 KB)

Fits spherical Gaussian mixture models to an in-distribution dataset and to
TinyImages out-of-distribution data, saving the models under SavedModels/GMM/.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import torch.nn as nn
import utils.models as models
import utils.dataloaders as dl
from sklearn import mixture
import numpy as np
import utils.gmm_helpers as gmm_helpers
import model_params
import torchvision.transforms as trn
import argparse
def _str2bool(value):
    """Parse a command-line boolean flag value.

    argparse's ``type=bool`` is a well-known pitfall: any non-empty string is
    truthy, so ``--augm_flag False`` silently evaluated to True.  This
    converter maps the usual textual spellings to the intended bool and
    rejects anything else.  Passing ``True``/``False`` keeps working as the
    original command lines intended.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % (value,))


parser = argparse.ArgumentParser(description='Define hyperparameters.')
parser.add_argument('--dataset', type=str, default='MNIST',
                    help='MNIST, SVHN, CIFAR10, CIFAR100.')
parser.add_argument('--data_used', type=int, default=None,
                    help='number of datapoints to be used.')
parser.add_argument('--augm_flag', type=_str2bool, default=False,
                    help='whether to use data augmentation.')
parser.add_argument('--PCA', type=_str2bool, default=False,
                    help='initialize for using in PCA metric.')
parser.add_argument('--n', nargs='+', type=int, default=[100],
                    help='mixture sizes (number of Gaussians) to fit.')

hps = parser.parse_args()
# Resolve dataset-specific settings and flatten the whole training set into
# a single (num_points, dim) tensor.
params = model_params.params_dict[hps.dataset](augm_flag=hps.augm_flag)
dim = params.dim
loader = params.train_loader
if hps.data_used is None:
    hps.data_used = params.data_used

batches = [batch.view(-1, dim) for batch, _ in loader]
X = torch.cat(batches, 0)
# Cap the number of points — needed to keep memory of the distance matrix
# below 800 GB.
X = X[:hps.data_used]

if hps.PCA:
    # Whiten the data in the PCA basis so the mixture is fitted in the
    # space the PCA metric measures distances in.
    metric = models.PCAMetric(X, p=2, min_sv_factor=1e6)
    X = (X @ metric.comp_vecs.t()) / metric.singular_values_sqrt[None, :]
else:
    metric = models.LpMetric()
# Fit one in-distribution GMM per requested mixture size and save it.
for n in hps.n:
    print(n)
    # NOTE(review): sklearn.mixture.GMM was removed in scikit-learn 0.20;
    # this script requires an older sklearn.  params='mc' updates only the
    # means and covariances during EM, leaving the initial weights fixed.
    clf = mixture.GMM(n_components=n, covariance_type='spherical', params='mc')
    clf.fit(X)

    mu = torch.tensor(clf.means_, dtype=torch.float)
    # Tie all components to one shared spherical variance: average the
    # fitted variances (in variance space) and store as per-component
    # log-variance with the original shape.
    logvar = torch.tensor(np.log(clf.covars_[:, 0]), dtype=torch.float)
    logvar = 0. * logvar + logvar.exp().mean().log()

    # (Removed: a models.GMM built before fitting and the unused `alpha`
    # log-weights tensor — both were dead code in the original.)
    gmm = models.GMM(n, dim, mu=mu, logvar=logvar, metric=metric)
    if hps.PCA:
        # Map centroids fitted in the whitened PCA space back to input space.
        gmm.mu.data = ((gmm.mu.data * metric.singular_values_sqrt[None, :])
                       @ metric.comp_vecs.t().inverse())

    saving_string = ('SavedModels/GMM/gmm_' + hps.dataset
                     + '_n' + str(n)
                     + '_data_used' + str(hps.data_used)
                     + '_augm_flag' + str(hps.augm_flag))
    if hps.PCA:
        saving_string += '_PCA'
    torch.save(gmm, saving_string + '.pth')
# Collect the out-distribution sample: the first ~201 batches from the
# TinyImages loader, flattened to (num_points, dim).
out_loader = dl.TinyImages(hps.dataset)
chunks = []
for batch_idx, (batch, _) in enumerate(out_loader):
    if batch_idx > 200:
        break
    chunks.append(batch.view(-1, dim))
X = torch.cat(chunks, 0)
if hps.PCA:
    # Project into the same whitened PCA basis used for the training data.
    X = (X @ metric.comp_vecs.t()) / metric.singular_values_sqrt[None, :]
# Fit one out-distribution GMM per mixture size on the TinyImages sample
# and save it with an '_OUT' suffix.
for n in hps.n:
    print(n)
    # NOTE(review): requires scikit-learn < 0.20 (mixture.GMM removed later).
    clf = mixture.GMM(n_components=n, covariance_type='spherical', params='mc')
    clf.fit(X)

    mu = torch.tensor(clf.means_, dtype=torch.float)
    # Shared spherical variance across components, stored as log-variance.
    logvar = torch.tensor(np.log(clf.covars_[:, 0]), dtype=torch.float)
    logvar = 0. * logvar + logvar.exp().mean().log()

    # (Removed: a models.GMM built before fitting and the unused `alpha`
    # log-weights tensor — both were dead code in the original.)
    gmm = models.GMM(n, dim, mu=mu, logvar=logvar, metric=metric)
    if hps.PCA:
        # Undo the whitening: map centroids back to the input space.
        gmm.mu.data = ((gmm.mu.data * metric.singular_values_sqrt[None, :])
                       @ metric.comp_vecs.t().inverse())

    saving_string = ('SavedModels/GMM/gmm_' + hps.dataset
                     + '_n' + str(n)
                     + '_data_used' + str(hps.data_used)
                     + '_augm_flag' + str(hps.augm_flag))
    if hps.PCA:
        saving_string += '_PCA'
    torch.save(gmm, saving_string + '_OUT' + '.pth')
print('Done')