-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmany_well.yaml
70 lines (60 loc) · 1.85 KB
/
many_well.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Optional Hydra defaults override (commented out):
# defaults:
#   - override hydra/launcher: joblib
# Target distribution settings.
target:
  # NOTE(review): presumably the dimensionality of the target (the
  # commented-out wandb run name in this file is ManyWell32) — confirm.
  dim: 32
# Flow model settings.
# NOTE(review): nesting below was reconstructed from a flattened copy of this
# file; confirm the key paths against the code that reads `flow.*`.
flow:
  layer_nodes_per_dim: 10
  n_layers: 10
  act_norm: false
  use_snf: false
  resampled_base: false
  # Config used only if snf is enabled (see use_snf above).
  snf:
    transition_operator_type: hmc  # metropolis or hmc
    step_size: 1.0
    # leapfrog steps if hmc, else metropolis accept/reject steps
    num_steps: 5
    it_snf_layer: 2
# FAB loss and transition-operator settings.
fab:
  # loss_type options:
  #   fab_alpha_div: standard FAB loss
  #   target_forward_kl: forward KL estimated with samples from the target
  #   flow_reverse_kl, flow_alpha_2_div_nis: reverse KL / alpha-2 divergence
  #     estimated with flow samples
  loss_type: fab_alpha_div
  # NOTE(review): the original inline comment was just "null" — presumably
  # null is an accepted alternative value; confirm.
  alpha: 2.0
  transition_operator:
    type: hmc
    n_inner_steps: 5
    init_step_size: 1.0
  # NOTE(review): placed directly under `fab` (sibling of
  # `transition_operator`); nesting was reconstructed — confirm.
  n_intermediate_distributions: 4
# Training-loop settings.
training:
  tlimit: null  # time limit in hours
  # NOTE(review): presumably a directory to resume from, or null to start
  # fresh — confirm.
  checkpoint_load_dir: null
  seed: 0  # interpolated into save_path via ${training.seed}
  # Written with an explicit decimal point: YAML 1.1 loaders such as plain
  # PyYAML resolve a dot-less exponent like 3e-4 as a string, not a float.
  lr: 3.0e-4
  batch_size: 2048
  n_iterations: null
  # 1e10. Written without `_` digit separators, which the YAML 1.2 core
  # schema does not resolve as an integer.
  n_flow_forward_pass: 10000000000
  use_gpu: true
  use_64_bit: true
  use_buffer: false
  prioritised_buffer: false
  buffer_temp: 0.0  # rate that we weight new experience over old
  n_batches_buffer_sampling: 8
  maximum_buffer_length: 512000
  min_buffer_length: 65536
  log_w_clip_frac: null
  max_grad_norm: 100.0
  # clipping of weight adjustment factor for prioritised replay
  w_adjust_max_clip: null
# Evaluation, plotting and checkpointing settings.
evaluation:
  n_plots: 50  # number of times we visualise the model throughout training.
  n_eval: 50  # for calculating metrics of flow w.r.t target.
  eval_batch_size: 2048  # must be a multiple of inner batch size
  n_checkpoints: 10  # number of model checkpoints saved
  # Output directory; ${training.seed} is interpolated from the training
  # section so each seed gets its own folder.
  # NOTE(review): placed under `evaluation`; nesting was reconstructed —
  # confirm it is not meant to be a top-level key.
  save_path: ./results/many_well/seed${training.seed}/
# Logging backend. pandas_logger is active; the alternatives are commented out.
logger:
  pandas_logger:
    save_period: 1000  # how often to save the pandas dataframe as a csv
  # wandb:
  #   name: ManyWell32
  #   project: fab
  #   entity: flow-ais-bootstrap
  #   tags: [alpha_2_loss, ManyWell32]
  # list_logger: