# part_segmentation_polarmae_fft.yml (82 lines, 2.91 KB)
# Lightning CLI config: part segmentation fine-tuning of PoLAr-MAE (FFT variant).
# Nesting reconstructed: trainer / model / data are top-level; ${...} values are
# OmegaConf interpolations (leading dots are relative to the enclosing mapping).
seed_everything: 0
trainer:
  sync_batchnorm: false # this doesn't do anything as we use custom masked batchnorm
  gradient_clip_val: 3
  devices: 1
  num_nodes: 1
  strategy: ddp
  max_steps: 100000
  precision: bf16-mixed
  val_check_interval: 500
  check_val_every_n_epoch: null # validate every val_check_interval steps, not per-epoch
  # optional wandb logging
  # logger:
  #   - class_path: pytorch_lightning.loggers.WandbLogger
  #     init_args:
  #       project: Part-Segmentation-LArNet-5voxel
  #       name: ViTS/5-10k-Multitask-FFT
model:
  pretrained_ckpt_path: null # ADD PATH TO CKPT HERE
  encoder_freeze: false
  num_classes: 4
  encoder:
    class_path: polarmae.layers.encoder.TransformerEncoder
    init_args:
      num_channels: 4
      arch: vit_small
      voxel_size: 5
      masking_ratio: 0.6
      masking_type: rand
      tokenizer_kwargs: # override tokenizer kwargs if wanted. for example, you can override the group radius from 5 voxels to below:
        group_radius: ${eval:'${model.encoder.init_args.voxel_size} * ${model.transformation_scale_factor}'} # equivalent to 5 voxel radius in normalized coords
      apply_relative_position_bias: false
      transformer_kwargs:
        postnorm: false
        add_pos_at_every_layer: true
        drop_rate: 0.0
        attn_drop_rate: 0.05
        drop_path_rate: 0.25
  seg_decoder: null # optional transformer-based part seg decoder
  #   class_path: larnet.layers.decoder.TransformerDecoder
  #   init_args:
  #     arch: ${...encoder.init_args.arch}
  #     transformer_kwargs:
  #       postnorm: true
  #       depth: 4
  #       add_pos_at_every_layer: ${....encoder.init_args.transformer_kwargs.add_pos_at_every_layer}
  #       drop_rate: ${....encoder.init_args.transformer_kwargs.drop_rate}
  #       attn_drop_rate: ${....encoder.init_args.transformer_kwargs.attn_drop_rate}
  #       drop_path_rate: ${....encoder.init_args.transformer_kwargs.drop_path_rate}
  condition_global_features: true
  seg_head_fetch_layers: [3, 7, 11]
  seg_head_dim: 512
  seg_head_dropout: 0.5
  loss_func: focal
  # learning parameters
  learning_rate: 0.0001
  optimizer_adamw_weight_decay: 0.05
  lr_scheduler_linear_warmup_epochs: 12500
  lr_scheduler_linear_warmup_start_lr: 8.6e-6
  lr_scheduler_cosine_eta_min: ${.lr_scheduler_linear_warmup_start_lr}
  lr_scheduler_stepping: step # or 'epoch'
  # other
  train_transformations:
    - "center_and_scale"
    - "rotate"
  val_transformations:
    - "center_and_scale"
  transformation_center: [384, 384, 384] # [768 / 2]*3. will be subtracted from the point cloud coordinates.
  transformation_scale_factor: ${eval:'1/ (${.transformation_center[0]} * (3**0.5))'} # 1 / (768 * sqrt(3) / 2)
data:
  class_path: polarmae.datasets.PILArNetDataModule
  init_args:
    data_path: null # ADD PATH TO DATA HERE
    batch_size: 32
    num_workers: 4
    dataset_kwargs:
      energy_threshold: 0.13
      remove_low_energy_scatters: true
      emin: 1.0e-2
      emax: 20.0
      maxlen: 10000
      min_points: 1024