forked from wolny/pytorch-3dunet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_config_ce.yaml
42 lines (42 loc) · 1.35 KB
/
test_config_ce.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Checkpoint file containing the model to load.
model_path: '3dunet/best_checkpoint.pytorch'
# Whether predicted patches are kept in memory or written directly to the H5
# file during prediction. Keeping them in memory is much faster, so enable
# this when enough RAM is available.
store_predictions_in_memory: true
# Model configuration. Every key below is nested under `model`; if the keys
# sit at column 0 they parse as top-level siblings and `model` becomes null.
model:
  # model class
  name: UNet3D
  # number of input channels to the model
  in_channels: 1
  # number of output channels
  out_channels: 2
  # order of operators in a single layer (crg - Conv3d+ReLU+GroupNorm)
  layer_order: crg
  # feature maps scale factor
  f_maps: 32
  # basic module
  basic_module: DoubleConv
  # number of groups in the groupnorm
  num_groups: 8
  # apply element-wise nn.Sigmoid after the final 1x1 convolution,
  # otherwise apply nn.Softmax
  final_sigmoid: false
# Test dataset configuration. Every key below is nested under `datasets`.
# NOTE(review): the nesting (in particular `transformer` living under
# `datasets`) was reconstructed from a copy that had lost its indentation -
# confirm against the config loader.
datasets:
  # patch size given to the network (adapt to fit in your GPU mem)
  patch: [64, 128, 128]
  # stride between patches (make sure the patches overlap in order to get
  # smoother prediction maps)
  stride: [32, 100, 100]
  # path to the raw data within the H5
  raw_internal_path: raw
  # how many subprocesses to use for data loading
  num_workers: 8
  # paths to the datasets
  test_path:
    - 'resources/random_label3D.h5'
  # transformations applied to the raw data in the test phase, in list order
  transformer:
    test:
      raw:
        - name: Normalize
        # `expand_dims` is an option of the ToTensor entry, so it must be
        # indented to align with that entry's `name` key
        - name: ToTensor
          expand_dims: true