-
Notifications
You must be signed in to change notification settings - Fork 0
/
launcher.py
51 lines (47 loc) · 2.27 KB
/
launcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from ZapfenDataset import ZapfenDataset
from models.ffnn import FFNN1, FFNN2, FFNN3, FFNN4
from train import train
from torch.nn import ReLU, LeakyReLU, CrossEntropyLoss
from torch import tensor
from config import TRAIN_CONFIG, NN_CONFIG, DATALOADER_CONFIG, WEIGHT_LOSS_FN
import torch
import sys
def _explore_dataset(csv_path='./zapfen.csv'):
    """Load the Zapfen CSV and run the dataset's plotting/cleaning steps in order.

    This is the exploration code that used to sit unreachable after ``return``
    statements inside ``launch``. It calls ``plot_label_distribution``, then
    ``plot_feature_distribution`` before ``fix_invalid_values()``, after it, and
    after ``scale()``, and finally ``plot_label_distribution`` once more.

    Args:
        csv_path: Path passed to ``ZapfenDataset``.
    """
    ds = ZapfenDataset(csv_path)
    ds.plot_label_distribution()
    ds.plot_feature_distribution(fname='feature_distr_before_fixing', title='before_fixing')
    ds.fix_invalid_values()
    ds.plot_feature_distribution(fname='feature_distr_after_fixing', title='after_fixing')
    ds.scale()
    ds.plot_feature_distribution(fname='after_fixing_and_normalizing', title='after_fixing_and_normalizing')
    ds.plot_label_distribution(fname='label_distr_after')


def launch(explore_data=False, seed=0):
    """Grid-search training hyperparameters, calling ``train`` once per configuration.

    For every combination of batch size, activation function, class-weighted
    loss on/off, network class, learning rate and momentum, the config dicts
    imported from ``config`` are updated in place. A new model is then
    constructed and ``train(model, context_str)`` is called.

    Args:
        explore_data: If True, run the dataset exploration steps
            (``_explore_dataset``) instead of the sweep, then return.
        seed: Passed to ``torch.manual_seed`` immediately before each model is
            constructed, so every run of a given architecture starts from the
            same initial weights. Pass None to skip seeding.
    """
    import itertools

    global WEIGHT_LOSS_FN

    if explore_data:
        _explore_dataset()
        return

    # Class weights for the weighted loss; presumably ordered by label index - TODO confirm.
    class_weights = [0.28, 0.1, 0.1, 0.52]
    paren_to_underscore = str.maketrans('()', '__')

    # itertools.product varies the last axis fastest, i.e. the same order as
    # the six nested loops it replaces.
    grid = itertools.product(
        [4, 1, 2, 8, 16],                           # batch sizes (4 first as before, no longer duplicated)
        [ReLU(), LeakyReLU()],                      # activation functions
        [True, False],                              # apply class weights to the loss?
        [FFNN1, FFNN2, FFNN3, FFNN4],               # model classes, instantiated fresh per run
        [0.0005, 0.001, 0.003, 0.005, 0.01, 0.03],  # learning rates
        [0.7, 0.75, 0.8, 0.85, 0.9],                # momenta
    )
    for batch_size, act_func, apply_weight, model_cls, lr, mom in grid:
        if apply_weight:
            loss_fn = CrossEntropyLoss(weight=tensor(class_weights))
        else:
            loss_fn = CrossEntropyLoss()

        DATALOADER_CONFIG['batch_size'] = batch_size
        TRAIN_CONFIG['loss_fn'] = loss_fn
        TRAIN_CONFIG['lr'] = lr
        TRAIN_CONFIG['momentum'] = mom
        # Was NN_CONFIG[act_func] (module object used as the key); a string key
        # matches the other config writes. NOTE(review): confirm the consumer reads 'act_func'.
        NN_CONFIG['act_func'] = act_func
        # NOTE(review): `global` rebinds only this module's WEIGHT_LOSS_FN (a copy
        # made by `from config import ...`); config.WEIGHT_LOSS_FN and copies held
        # by other modules are unchanged. Verify whether train() expects to see it.
        WEIGHT_LOSS_FN = apply_weight

        # Build the model per run (previously one instance was trained repeatedly
        # across all lr/momentum values) and after NN_CONFIG is updated, in case
        # the constructor reads it - TODO confirm.
        if seed is not None:
            torch.manual_seed(seed)
        model = model_cls()

        weight = 'weight_applied' if apply_weight else ''
        # NOTE(review): lr, momentum and model class are not part of this string;
        # confirm train() distinguishes those runs some other way.
        context_str = f'{batch_size}_{loss_fn}{act_func}{weight}'.translate(paren_to_underscore)
        train(model, context_str)
if __name__ == '__main__':
    # Script entry point: run launch() with its default arguments.
    launch()