-
Notifications
You must be signed in to change notification settings - Fork 4
/
setup.py
129 lines (102 loc) · 3.96 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import datasets
import generators
import predictors
import torch
import trainer
# Torch device used for all predictors: the GPU when CUDA is available, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Maps the `env.teacher` / `env.student` option strings to predictor classes.
predictor_dict = dict(
    half_lenet=predictors.HalfLeNet,
    inceptionv3=predictors.InceptionV3,
    lenet=predictors.LeNet,
    resnet18=predictors.ResNet18,
    vgg=predictors.VGG,
    alexnet=predictors.Alexnet,
    half_alexnet=predictors.HalfAlexnet,
)
# Maps the `env.true_dataset` / `env.samples` option strings to dataset classes.
# Currently disabled entries (kept for reference):
#   'discrepancy'    -> datasets.Discrepancy
#   'discrepancy_kl' -> datasets.Discrepancy_KL
#   'curriculum'     -> datasets.Curriculum
#   'two_gans'       -> datasets.TwoGANs
dataset_dict = dict(
    cifar10=datasets.CIFAR10,
    fmnist=datasets.FMNIST,
    optimized=datasets.OptimizedFromGenerator,
    random=datasets.RandomFromGenerator,
    split_fmnist=datasets.SplitFMNIST,
)
# Maps the `env.generator` option strings to generator classes; several
# registry names share one architecture and differ only in how they are
# prepared (see generator_prepare_dict).
generator_dict = {
    **dict.fromkeys(
        ('cifar_10_gan', 'cifar_100_90_classes_gan', 'cifar_100_40_classes_gan'),
        generators.SNGAN,
    ),
    'cifar_10_vae': generators.VAE,
    **dict.fromkeys(
        ('cifar_100_6_classes_gan', 'cifar_100_10_classes_gan'),
        generators.Progan,
    ),
}
# Maps each generator_dict key to the trainer routine that receives the fresh
# generator instance; prepare_generator uses the routine's return value as the
# generator.
generator_prepare_dict = dict(
    cifar_10_gan=trainer.train_or_restore_cifar_10_gan,
    cifar_100_90_classes_gan=trainer.train_or_restore_cifar_100_90_classes_gan,
    cifar_100_40_classes_gan=trainer.train_or_restore_cifar_100_40_classes_gan,
    cifar_10_vae=trainer.train_or_restore_cifar_10_vae,
    cifar_100_6_classes_gan=trainer.train_or_restore_cifar_100_6_classes_gan,
    cifar_100_10_classes_gan=trainer.train_or_restore_cifar_100_10_classes_gan,
)
def prepare_teacher_student(env):
    """Build the true dataset, a trained teacher, and an untrained student.

    Args:
        env: Options object; reads ``true_dataset``, ``size``, ``teacher``,
            ``student`` and ``optim`` (plus whatever ``teacher_name`` /
            ``student_name`` read).

    Returns:
        Tuple ``(teacher, true_dataset, student)``. The teacher has been moved
        to ``device``, trained or restored, and had ``.eval()`` called on it.
        The student has only been constructed and moved to ``device``.
    """
    true_dataset = dataset_dict[env.true_dataset](input_size=env.size)

    teacher_cls = predictor_dict[env.teacher]
    teacher = teacher_cls(
        name=teacher_name(env),
        n_outputs=true_dataset.n_classes,
    )
    teacher.to(device)

    # 'sgd' selects the default trainer; any other value uses the _adam variant.
    train_teacher = (
        trainer.train_or_restore_predictor
        if env.optim == 'sgd'
        else trainer.train_or_restore_predictor_adam
    )
    train_teacher(teacher, true_dataset)
    teacher.eval()

    student_cls = predictor_dict[env.student]
    student = student_cls(
        name=student_name(env),
        n_outputs=true_dataset.n_classes,
    )
    student.to(device)

    return teacher, true_dataset, student
# Registry keys used to assemble the 'combined' generator. The registry has no
# 'dcgan'/'gan' entries (the names the old code looked up), so those lookups
# always raised KeyError.
# NOTE(review): confirm these are the intended VAE/GAN pair for 'combined'.
_COMBINED_VAE_KEY = 'cifar_10_vae'
_COMBINED_GAN_KEY = 'cifar_10_gan'


def _build_generator(key):
    """Instantiate the generator registered under ``key`` and prepare it.

    The fresh ``generator_dict[key]()`` instance is passed to
    ``generator_prepare_dict[key]``, and that routine's return value is
    returned.
    """
    generator = generator_dict[key]()
    return generator_prepare_dict[key](generator)


class _CombinedGenerator:
    """Hold a VAE and a GAN and delegate calls to whichever is active.

    The GAN is active initially; ``switch()`` toggles between the two.
    """

    # Encoding sizes reported by encoding_size() for each state.
    _GAN_ENCODING_SIZE = 128
    _VAE_ENCODING_SIZE = 100

    def __init__(self, vae, gan):
        self.vae = vae
        self.gan = gan
        self.current_generator = self.gan
        self.current_state = 'gan'

    def __call__(self, inputs):
        """Run ``inputs`` through the currently active generator."""
        return self.current_generator(inputs)

    def switch(self):
        """Make the inactive generator the active one."""
        if self.current_state == 'vae':
            self.current_generator, self.current_state = self.gan, 'gan'
        else:
            self.current_generator, self.current_state = self.vae, 'vae'

    def encoding_size(self):
        """Return the encoding size for the currently active generator."""
        if self.current_state == 'gan':
            return self._GAN_ENCODING_SIZE
        return self._VAE_ENCODING_SIZE


def prepare_generator(env):
    """Build and train-or-restore the generator selected by ``env.generator``.

    ``'combined'`` builds both the VAE and the GAN (VAE first) and wraps them in
    a switchable ``_CombinedGenerator``. Any other value must be a key of
    ``generator_dict`` / ``generator_prepare_dict``.

    Raises:
        KeyError: if ``env.generator`` is not ``'combined'`` and not a
            registered generator name.
    """
    if env.generator == 'combined':
        vae = _build_generator(_COMBINED_VAE_KEY)
        gan = _build_generator(_COMBINED_GAN_KEY)
        return _CombinedGenerator(vae, gan)
    return _build_generator(env.generator)
def prepare_student_dataset(env, teacher, teacher_dataset, student, generator):
    """Build the dataset used to train the student.

    The class is chosen by ``env.samples`` from ``dataset_dict``. It receives the
    generator, teacher and student positionally, the teacher dataset's
    ``test_dataloader``, and a ``to_grayscale`` flag.
    """
    dataset_cls = dataset_dict[env.samples]
    test_loader = teacher_dataset.test_dataloader
    # NOTE(review): only generator names containing 'gan' trigger grayscale
    # conversion for FMNIST; 'cifar_10_vae' and 'combined' do not. Confirm this
    # is intended.
    grayscale = 'gan' in env.generator and 'fmnist' in env.true_dataset
    return dataset_cls(
        generator, teacher, student,
        test_dataloader=test_loader,
        to_grayscale=grayscale,
    )
def teacher_name(env):
    """Return the ``name`` given to the teacher predictor for this env.

    Built from ``env.teacher`` and ``env.true_dataset``.
    """
    architecture, dataset = env.teacher, env.true_dataset
    return f'teacher_{architecture}_for_{dataset}'
def student_name(env):
    """Return the ``name`` given to the student predictor for this env.

    Encodes the student and teacher architectures, the true dataset, the
    generator, the sampling strategy, the optimizer and the epoch count, joined
    with underscores.
    """
    components = (
        f'student_{env.student}',
        f'for_teacher_{env.teacher}',
        f'true_{env.true_dataset}',
        f'generator_{env.generator}',
        f'samples_{env.samples}',
        f'optim_{env.optim}',
        f'epochs_{env.epochs}',
    )
    return '_'.join(components)