# @package _global_
# to execute this experiment run:
# python examples/1_bias_bios_full.py
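# Any value in this file can also be overridden from the command line via Hydra's
# key=value syntax; the overrides below are purely illustrative:
# python examples/1_bias_bios_full.py seed=7 trainer.max_epochs=50 Weasel.temperature=1.0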
seed: 3
ignore_warnings: True
datamodule:
  _target_: examples.datamodules.ProfTeacher_datamodule.ProfTeacher_DataModule
  # The following three params are ignored by the datamodule itself, but they are
  # semantically tied to the dataset, which is why they are defined here. In practice,
  # only Weasel needs to know them; see the Weasel params below, where they are all
  # inferred automatically from these datamodule values.
  num_LFs: 99
  n_classes: 2
  class_balance: [0.5, 0.5]
  # Actual params used by the DataModule
  batch_size: 64
  val_test_split: [250, -1]
  num_workers: 0
  pin_memory: True
  seed: 3
trainer:
  _target_: pytorch_lightning.Trainer
  devices: "auto"  # number of CPUs or GPUs to use
  accelerator: "auto"  # e.g. "cpu" or "gpu"; for distributed training, set the Trainer's strategy (e.g. "ddp") instead
  min_epochs: 1
  max_epochs: 100
  # resume_from_checkpoint: ${work_dir}/last.ckpt
end_model:
  _target_: weasel.models.downstream_models.MLP.MLPNet
  dropout: 0.3
  net_norm: "none"
  activation_func: "ReLU"
  input_dim: 300
  hidden_dims: [50, 50, 25]
  output_dim: ${datamodule.n_classes}
  adjust_thresh: True
Weasel:
  _target_: weasel.models.weasel.Weasel
  # Very convenient Hydra/OmegaConf interpolations that infer these params from the dataset details above:
  num_LFs: ${datamodule.num_LFs}
  n_classes: ${datamodule.n_classes}
  class_balance: ${datamodule.class_balance}
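  # With the values defined above, these interpolations resolve to 99, 2, and [0.5, 0.5], respectively.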
  loss_function: "cross_entropy"
  temperature: 2.0
  accuracy_scaler: "sqrt"
  use_aux_input_for_encoder: True
  class_conditional_accuracies: True
  encoder:
    _target_: weasel.models.encoder_models.encoder_MLP.MLPEncoder
    # example_extra_input: [1, 300]  # OR (1, ${end_model.input_dim})
    # |--> Commented out, since it is better to specify the 'example_input_array' attribute in your end-model.
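    # A minimal sketch of that alternative (using the standard PyTorch Lightning
    # 'example_input_array' attribute; the shape is illustrative and just mirrors
    # end_model.input_dim):
    #   self.example_input_array = torch.randn(1, 300)  # set in the end-model's __init__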
    dropout: 0.3
    net_norm: "batch_norm"
    activation_func: "ReLU"
    hidden_dims: [70, 70]
  optim_end_model:
    _target_: torch.optim.Adam
    lr: 1e-4
    weight_decay: 7e-7
  optim_encoder:
    _target_: torch.optim.Adam
    lr: 1e-4
    weight_decay: 7e-7
  scheduler: null
  verbose: True
callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "Val/auc"
    save_top_k: 2
    save_last: True
    mode: "max"
    dirpath: "checkpoints/"
    filename: "example-profTeacher-{epoch:02d}"
  watch_model:  # would log histograms of model gradients to WandB; TODO: currently not working, so its _target_ is commented out
    # _target_: weasel.callbacks.wandb_callbacks.WatchModel
    log: "gradients"
    log_freq: 100
logger:
  wandb:
    _target_: pytorch_lightning.loggers.wandb.WandbLogger
    # entity: ""  # optionally set to the name of your wandb team
    name: null
    tags: []
    notes: "..."
    project: "Weasel_ProfTeacher"
    group: ""
    mode: "online"  # set to "disabled" to turn off wandb logging
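
# For reference, a minimal sketch of how a Hydra entry point might consume this config.
# This is illustrative only (the config_path and the exact wiring are assumptions, not
# the actual contents of examples/1_bias_bios_full.py):
#
#   import hydra
#   from hydra.utils import instantiate
#
#   @hydra.main(config_path="configs", config_name="profTeacher_full")
#   def main(cfg):
#       datamodule = instantiate(cfg.datamodule)
#       end_model = instantiate(cfg.end_model)
#       # _recursive_=False leaves the nested encoder/optimizer configs for Weasel to handle
#       weasel = instantiate(cfg.Weasel, end_model=end_model, _recursive_=False)
#       trainer = instantiate(cfg.trainer)
#       trainer.fit(weasel, datamodule=datamodule)
#
#   if __name__ == "__main__":
#       main()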