-
Notifications
You must be signed in to change notification settings - Fork 2
/
20way1shot.yaml
43 lines (39 loc) · 1.05 KB
/
20way1shot.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Top-level experiment name.
name: omniglot_protonet

# Dataset settings. These keys must be nested under `data`; at column 0
# they would be parsed as unrelated top-level keys and `data` would be null.
data:
  data_path: ./data
  validation_portion: 0.25
  tasks_per_epoch_train: 100
  tasks_per_epoch_val: 1000
  train_workers: 8
  val_workers: 4
# Hyperparameters for the trial. All keys below are children of
# `hyperparameters`; the indentation is what makes them so.
hyperparameters:
  learning_rate: 1.0e-3
  weight_decay: 0
  reduce_every: 200
  lr_gamma: 0.5
  global_batch_size: 2  # how many tasks to train before performing a meta-update
  val_batch_size: 2  # how many tasks to evaluate on

  # Meta-training
  num_classes_train: 60
  num_support_train: 1
  num_query_train: 5

  # Meta-test
  num_classes_val: 20  # n-way
  num_support_val: 1  # k-shot

  # Model
  img_resize_dim: 28  # input will be 1 x img_resize_dim x img_resize_dim
  hidden_dim: 64  # intermediate number of channels
  embedding_dim: 64  # embedding number of channels
# Resource request; `slots_per_trial` is nested so it belongs to `resources`
# instead of being a stray top-level key.
resources:
  slots_per_trial: 2
# Searcher settings. `name` here is the searcher's name and must be nested:
# at column 0 it would duplicate (and, in most parsers, override) the
# top-level experiment `name`.
searcher:
  name: single
  metric: loss
  smaller_is_better: true
  # Original paper trained for 10,000 epochs with a plateau stopping condition
  max_length:
    batches: 30000
# Entry point in `module:Class` form.
entrypoint: model_def:OmniglotProtoNetTrial

# `batches` is nested under `min_validation_period`; left at column 0 it
# would be a top-level key and `min_validation_period` would be null.
min_validation_period:
  batches: 5000

# Plain string `none` (not YAML null, which is spelled `null` or `~`).
checkpoint_policy: none