-
Notifications
You must be signed in to change notification settings - Fork 5
/
fusion.py
121 lines (105 loc) · 3.62 KB
/
fusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from config.experiments.base_experiments import *
from config.templates.models import eamc, simvc_comvc, custom
from config.templates.fusion import Concat
# ======================================================================================================================
# NoisyMNIST
# ======================================================================================================================
# Fusion ablation, EAMC on NoisyMNIST: attention module disabled (attention_config=None)
# and fusion replaced by plain concatenation.
ablfusion_noisymnist_eamc = NoisyMNISTExperiment(
    model_config=eamc.EAMC(
        encoder_configs=NOISY_MNIST_ENCODERS,
        attention_config=None,
        fusion_config=Concat(),
        loss_config=eamc.EAMCLoss(
            # Presumably one weight per term in `funcs`, in the same order (5 weights, 5 terms).
            weights=[1, 1, 1, 10, 1],
            funcs="DDC1|DDC2Flipped|DDC3|EAMCGenerator|EAMCDiscriminator",
        ),
    ),
    wandb_tags="ablation,fusion",
)
# Fusion ablation, SiMVC on NoisyMNIST: concatenation fusion, all other settings left at their defaults.
ablfusion_noisymnist_simvc = NoisyMNISTExperiment(
    model_config=simvc_comvc.SiMVC(encoder_configs=NOISY_MNIST_ENCODERS, fusion_config=Concat()),
    wandb_tags="ablation,fusion",
)
# Fusion ablation, CoMVC on NoisyMNIST: concatenation fusion, with the loss's
# `contrast_adaptive_weight` option switched off.
ablfusion_noisymnist_comvc = NoisyMNISTExperiment(
    model_config=simvc_comvc.CoMVC(
        encoder_configs=NOISY_MNIST_ENCODERS,
        fusion_config=Concat(),
        loss_config=simvc_comvc.CoMVCLoss(contrast_adaptive_weight=False),
    ),
    wandb_tags="ablation,fusion",
)
# Fusion ablation, "SAE" variant of custom.CAE on NoisyMNIST:
#   - training loss restricted to DDC1|DDC2|DDC3|MSE,
#   - pre-training loss is MSE only,
#   - no projector (projector_config=None),
#   - concatenation fusion.
ablfusion_noisymnist_sae = NoisyMNISTExperiment(
    model_config=custom.CAE(
        encoder_configs=NOISY_MNIST_ENCODERS,
        decoder_configs=NOISY_MNIST_DECODERS,
        loss_config=custom.CAELoss(funcs="DDC1|DDC2|DDC3|MSE"),
        pre_train_loss_config=custom.CAELoss(funcs="MSE"),
        projector_config=None,
        fusion_config=Concat(),
    ),
    wandb_tags="ablation,fusion",
)
# Fusion ablation, custom.CAE on NoisyMNIST: only the fusion is overridden (concatenation);
# losses and projector use whatever CAE defaults to.
ablfusion_noisymnist_cae = NoisyMNISTExperiment(
    model_config=custom.CAE(
        encoder_configs=NOISY_MNIST_ENCODERS, decoder_configs=NOISY_MNIST_DECODERS, fusion_config=Concat(),
    ),
    wandb_tags="ablation,fusion",
)
# ======================================================================================================================
# Caltech7
# ======================================================================================================================
# Fusion ablation, EAMC on Caltech7: attention module disabled (attention_config=None)
# and fusion replaced by plain concatenation.
ablfusion_caltech7_eamc = Caltech7Experiment(
    model_config=eamc.EAMC(
        encoder_configs=CALTECH_ENCODERS,
        attention_config=None,
        fusion_config=Concat(),
        loss_config=eamc.EAMCLoss(
            # Presumably one weight per term in `funcs`, in the same order (5 weights, 5 terms).
            weights=[1, 1, 1, 10, 1],
            funcs="DDC1|DDC2Flipped|DDC3|EAMCGenerator|EAMCDiscriminator",
        ),
    ),
    wandb_tags="ablation,fusion",
)
# Fusion ablation, SiMVC on Caltech7: concatenation fusion, all other settings left at their defaults.
ablfusion_caltech7_simvc = Caltech7Experiment(
    model_config=simvc_comvc.SiMVC(encoder_configs=CALTECH_ENCODERS, fusion_config=Concat()),
    wandb_tags="ablation,fusion",
)
# Fusion ablation, CoMVC on Caltech7: concatenation fusion, with the loss's
# `contrast_adaptive_weight` option switched off.
ablfusion_caltech7_comvc = Caltech7Experiment(
    model_config=simvc_comvc.CoMVC(
        encoder_configs=CALTECH_ENCODERS,
        fusion_config=Concat(),
        loss_config=simvc_comvc.CoMVCLoss(contrast_adaptive_weight=False),
    ),
    wandb_tags="ablation,fusion",
)
# Fusion ablation, "SAE" variant of custom.CAE on Caltech7:
#   - training loss restricted to DDC1|DDC2|DDC3|MSE,
#   - pre-training loss is MSE only,
#   - no projector (projector_config=None),
#   - concatenation fusion.
ablfusion_caltech7_sae = Caltech7Experiment(
    model_config=custom.CAE(
        encoder_configs=CALTECH_ENCODERS,
        decoder_configs=CALTECH_DECODERS,
        loss_config=custom.CAELoss(funcs="DDC1|DDC2|DDC3|MSE"),
        pre_train_loss_config=custom.CAELoss(funcs="MSE"),
        projector_config=None,
        fusion_config=Concat(),
    ),
    wandb_tags="ablation,fusion",
)
# Fusion ablation, custom.CAE on Caltech7: only the fusion is overridden (concatenation);
# losses and projector use whatever CAE defaults to.
ablfusion_caltech7_cae = Caltech7Experiment(
    model_config=custom.CAE(
        encoder_configs=CALTECH_ENCODERS, decoder_configs=CALTECH_DECODERS, fusion_config=Concat(),
    ),
    wandb_tags="ablation,fusion",
)