-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconfig.yaml
102 lines (91 loc) · 4.38 KB
/
config.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# lightning.pytorch==2.0.9.post0
# Global random seed for the run.
seed_everything: 42
# Settings for the Lightning Trainer. Every key below must be nested under
# `trainer:`; at column 0 they would be parsed as unrelated top-level keys.
trainer:
  accelerator: gpu
  strategy: ddp_find_unused_parameters_true
  devices: [0]  # Explicit list of device indices
  num_nodes: 1
  precision: 32
  logger:
    - class_path: CSVLogger
      init_args:
        save_dir: logs
        name: experiment
  callbacks:
    - class_path: RichProgressBar
      init_args:
        leave: true
    - class_path: ModelCheckpoint
      init_args:
        # Checkpoint cadence matches `check_val_every_n_epoch` below.
        every_n_epochs: 10
        save_on_train_epoch_end: true
        save_top_k: -1  # -1 keeps every checkpoint instead of the best k
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: step
  max_epochs: 50
  check_val_every_n_epoch: 10
  num_sanity_val_steps: 0  # 0 disables the pre-training validation pass
  log_every_n_steps: 1
  deterministic: true
  inference_mode: true
  use_distributed_sampler: true
  sync_batchnorm: true
# Model hyperparameters. The three `*_params` mappings are nested one level
# deeper than the optimizer/evaluation settings.
model:
  lr_g: 0.00016  # Learning rate for the generator
  lr_d: 0.0001  # Learning rate for the discriminator
  # Frequency of the gradient penalty calculation for the discriminator
  disc_grad_penalty_freq: 10
  disc_grad_penalty_weight: 0.5  # Gradient penalty weight (discriminator)
  lambda_rec_loss: 0.5  # Reconstruction loss weight
  optim_betas: [0.5, 0.9]  # Adam optimizer betas
  eval_mask: true  # Whether to use a mask during test
  # Whether to evaluate subject-wise during test (`subject_ids` should be
  # provided in the test dataset)
  eval_subject: true
  diffusion_params:
    n_steps: 10  # Number of diffusion steps
    beta_start: 0.1  # Beta start value of the diffusion process
    beta_end: 3.0  # Beta end value of the diffusion process
    # Gamma value that controls noise in the end-point of the bridge
    gamma: 1
    n_recursions: 2  # Max number of recursions (R)
    # Self-consistency threshold for the recursive step
    consistency_threshold: 0.01
  generator_params:
    self_recursion: true  # Whether to use self-consistent recursion
    # Interpolated from the `data` section so both stay in sync.
    image_size: ${data.image_size}
    z_emb_dim: 256  # Dimension of the latent embedding
    ch_mult: [1, 1, 2, 2, 4, 4]  # Channel multipliers for each resolution
    num_res_blocks: 2  # Number of residual blocks
    attn_resolutions: [16]  # Resolutions at which attention is applied
    dropout: 0.0  # Dropout rate
    resamp_with_conv: true  # Whether to use convolutional resampling
    conditional: true  # Whether to condition on the time embedding
    fir: true  # Whether to use FIR filters
    fir_kernel: [1, 3, 3, 1]  # FIR filter kernel
    # NOTE(review): the name suggests this *enables* rescaling of the skip
    # connections rather than skipping it - confirm against the generator.
    skip_rescale: true
    resblock_type: biggan  # Type of the residual block
    # Progressive-growing mode; the string 'none' (not YAML null) disables it.
    progressive: 'none'
    progressive_input: residual  # Input type for the progressive path
    embedding_type: positional  # Embedding type
    combine_method: sum  # Method used to combine the skip connection
    fourier_scale: 16  # Fourier scale
    nf: 64  # Number of filters
    num_channels: 2  # Number of channels in the input
    nz: 100  # Number of latent dimensions
    n_mlp: 3  # Number of MLP layers
    centered: true  # Whether to center the input
    # Whether to omit the tanh activation (false keeps tanh enabled)
    not_use_tanh: false
  discriminator_params:
    nc: 2  # Number of channels in the input (x_t, x_{t-1})
    # NOTE(review): described as "number of generator filters" although it
    # sits under the discriminator - confirm the intended meaning.
    ngf: 32
    t_emb_dim: 256  # Dimension of the temporal embedding
# Data settings. `image_size` here is the value referenced by
# `${data.image_size}` elsewhere in the config, so it must stay nested
# under `data:`.
data:
  train_batch_size: 4
  val_batch_size: 4
  test_batch_size: 32
  dataset_dir: ./dataset
  source_modality: T1
  target_modality: T2
  # Dataset class name that can be customized in datasets.py
  dataset_class: NumpyDataset
  image_size: 256  # Image size used in padding
  padding: true  # Whether to pad the input
  norm: true  # Whether to normalize the input
  num_workers: 32
# Checkpoint path; explicit null means no checkpoint is supplied.
ckpt_path: null