# predict.py (forked from wolny/pytorch-3dunet)
import os

import h5py
import numpy as np
import torch

from datasets.hdf5 import get_test_loaders
from unet3d import utils
from unet3d.config import load_config
from unet3d.model import get_model

logger = utils.get_logger('UNet3DPredictor')


def predict_in_memory(model, data_loader, output_file, config):
    """
    Applies the model to the given dataset, accumulates the prediction masks for
    the whole volume in memory and saves them to the output H5 file.

    Args:
        model (Unet3D): trained 3D UNet model used for prediction
        data_loader (torch.utils.data.DataLoader): input data loader
        output_file (str): path to the output H5 file
        config (dict): global config dict
    """

    def _volume_shape(dataset):
        # TODO: support multiple internal datasets
        raw = dataset.raws[0]
        if raw.ndim == 3:
            return raw.shape
        else:
            return raw.shape[1:]

    out_channels = config['model'].get('out_channels')
    if out_channels is None:
        out_channels = config['model']['dt_out_channels']

    prediction_channel = config.get('prediction_channel', None)
    if prediction_channel is not None:
        logger.info(f"Using only channel '{prediction_channel}' from the network output")

    device = config['device']
    output_heads = config['model'].get('output_heads', 1)

    logger.info(f'Running prediction on {len(data_loader)} patches...')
    # dimensionality of the output (CxDxHxW)
    volume_shape = _volume_shape(data_loader.dataset)
    if prediction_channel is None:
        prediction_maps_shape = (out_channels,) + volume_shape
    else:
        # single channel prediction map
        prediction_maps_shape = (1,) + volume_shape
    logger.info(f'The shape of the output prediction maps (CDHW): {prediction_maps_shape}')

    # initialize the output prediction arrays
    prediction_maps = [np.zeros(prediction_maps_shape, dtype='float32') for _ in range(output_heads)]
    # initialize normalization mask in order to average out probabilities of overlapping patches
    normalization_masks = [np.zeros(prediction_maps_shape, dtype='float32') for _ in range(output_heads)]
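
    # Overlapping patches are averaged: each voxel accumulates the sum of all patch
    # predictions covering it, while the normalization mask counts the visits, so the
    # final value written out below is sum(predictions at voxel) / number of visits.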

    # Sets the module in evaluation mode explicitly, otherwise the final Softmax/Sigmoid won't be applied!
    model.eval()
    # Run predictions on the entire input dataset
    with torch.no_grad():
        for patch, index in data_loader:
            logger.info(f'Predicting slice: {index}')

            # save patch index: (C,D,H,W)
            if prediction_channel is None:
                channel_slice = slice(0, out_channels)
            else:
                channel_slice = slice(0, 1)
            index = (channel_slice,) + tuple(index)

            # send patch to device
            patch = patch.to(device)
            # forward pass
            predictions = model(patch)

            # wrap predictions into a list if there is only one output head from the network
            if output_heads == 1:
                predictions = [predictions]

            for prediction, prediction_map, normalization_mask in zip(predictions, prediction_maps,
                                                                      normalization_masks):
                # squeeze batch dimension and convert back to numpy array
                prediction = prediction.squeeze(dim=0).cpu().numpy()
                if prediction_channel is not None:
                    # use only the 'prediction_channel'
                    logger.info(f"Using channel '{prediction_channel}'...")
                    prediction = np.expand_dims(prediction[prediction_channel], axis=0)

                # unpad in order to avoid block artifacts in the output probability maps
                u_prediction, u_index = utils.unpad(prediction, index, volume_shape)
                # accumulate probabilities into the output prediction array
                prediction_map[u_index] += u_prediction
                # count voxel visits for normalization
                normalization_mask[u_index] += 1

    # save probability maps
    prediction_datasets = _get_dataset_names(config, output_heads, prefix='predictions')
    with h5py.File(output_file, 'w') as f:
        for prediction_map, normalization_mask, prediction_dataset in zip(prediction_maps, normalization_masks,
                                                                          prediction_datasets):
            prediction_map = prediction_map / normalization_mask
            logger.info(f'Saving predictions to: {output_file}/{prediction_dataset}...')
            f.create_dataset(prediction_dataset, data=prediction_map, compression="gzip")
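

# NOTE (worked example, illustrative numbers): predict_in_memory keeps one float32
# prediction map plus one float32 normalization mask per output head in RAM. For
# out_channels=2 and a 512x512x512 volume that is 2*512^3*4 bytes, i.e. about 1 GiB
# per array and roughly 2 GiB per head, before anything is written to the H5 file.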


def predict(model, data_loader, output_file, config):
    """
    Applies the model to the given dataset and saves the prediction masks in the
    output H5 file on a patch by patch basis.
    If your dataset fits into memory, use predict_in_memory(), which is much faster.

    Args:
        model (Unet3D): trained 3D UNet model used for prediction
        data_loader (torch.utils.data.DataLoader): input data loader
        output_file (str): path to the output H5 file
        config (dict): global config dict
    """

    def _volume_shape(dataset):
        # TODO: support multiple internal datasets
        raw = dataset.raws[0]
        if raw.ndim == 3:
            return raw.shape
        else:
            return raw.shape[1:]

    out_channels = config['model'].get('out_channels')
    if out_channels is None:
        out_channels = config['model']['dt_out_channels']

    prediction_channel = config.get('prediction_channel', None)
    if prediction_channel is not None:
        logger.info(f"Using only channel '{prediction_channel}' from the network output")

    device = config['device']
    output_heads = config['model'].get('output_heads', 1)

    logger.info(f'Running prediction on {len(data_loader)} patches...')
    # dimensionality of the output (CxDxHxW)
    volume_shape = _volume_shape(data_loader.dataset)
    if prediction_channel is None:
        prediction_maps_shape = (out_channels,) + volume_shape
    else:
        # single channel prediction map
        prediction_maps_shape = (1,) + volume_shape
    logger.info(f'The shape of the output prediction maps (CDHW): {prediction_maps_shape}')

    with h5py.File(output_file, 'w') as f:
        # allocate datasets for probability maps
        prediction_datasets = _get_dataset_names(config, output_heads, prefix='predictions')
        prediction_maps = [
            f.create_dataset(dataset_name, shape=prediction_maps_shape, dtype='float32', chunks=True,
                             compression='gzip')
            for dataset_name in prediction_datasets]

        # allocate datasets for normalization masks
        normalization_datasets = _get_dataset_names(config, output_heads, prefix='normalization')
        normalization_masks = [
            f.create_dataset(dataset_name, shape=prediction_maps_shape, dtype='uint8', chunks=True,
                             compression='gzip')
            for dataset_name in normalization_datasets]
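
        # uint8 visit counts keep the on-disk normalization masks small; this assumes
        # no voxel is covered by more than 255 overlapping patches.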

        # Sets the module in evaluation mode explicitly, otherwise the final Softmax/Sigmoid won't be applied!
        model.eval()
        # Run predictions on the entire input dataset
        with torch.no_grad():
            for patch, index in data_loader:
                logger.info(f'Predicting slice: {index}')

                # save patch index: (C,D,H,W)
                if prediction_channel is None:
                    channel_slice = slice(0, out_channels)
                else:
                    channel_slice = slice(0, 1)
                index = (channel_slice,) + tuple(index)

                # send patch to device
                patch = patch.to(device)
                # forward pass
                predictions = model(patch)

                # wrap predictions into a list if there is only one output head from the network
                if output_heads == 1:
                    predictions = [predictions]

                for prediction, prediction_map, normalization_mask in zip(predictions, prediction_maps,
                                                                          normalization_masks):
                    # squeeze batch dimension and convert back to numpy array
                    prediction = prediction.squeeze(dim=0).cpu().numpy()
                    if prediction_channel is not None:
                        # use only the 'prediction_channel'
                        logger.info(f"Using channel '{prediction_channel}'...")
                        prediction = np.expand_dims(prediction[prediction_channel], axis=0)

                    # unpad in order to avoid block artifacts in the output probability maps
                    u_prediction, u_index = utils.unpad(prediction, index, volume_shape)
                    # accumulate probabilities into the output prediction array
                    prediction_map[u_index] += u_prediction
                    # count voxel visits for normalization
                    normalization_mask[u_index] += 1

        # normalize the prediction_maps inside the H5
        for prediction_map, normalization_mask, prediction_dataset, normalization_dataset in zip(prediction_maps,
                                                                                                 normalization_masks,
                                                                                                 prediction_datasets,
                                                                                                 normalization_datasets):
            # TODO: iterate block by block
            # split the volume into 4 parts and load each into memory separately
            logger.info(f'Normalizing {prediction_dataset}...')
            z, y, x = volume_shape
            mid_x = x // 2
            mid_y = y // 2
            prediction_map[:, :, 0:mid_y, 0:mid_x] /= normalization_mask[:, :, 0:mid_y, 0:mid_x]
            prediction_map[:, :, mid_y:, 0:mid_x] /= normalization_mask[:, :, mid_y:, 0:mid_x]
            prediction_map[:, :, 0:mid_y, mid_x:] /= normalization_mask[:, :, 0:mid_y, mid_x:]
            prediction_map[:, :, mid_y:, mid_x:] /= normalization_mask[:, :, mid_y:, mid_x:]
            logger.info(f'Deleting {normalization_dataset}...')
            del f[normalization_dataset]


def _get_output_file(dataset, suffix='_predictions'):
    return f'{os.path.splitext(dataset.file_path)[0]}{suffix}.h5'
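
# e.g. (illustrative path): a dataset read from '/data/sample.h5' gets its
# predictions written to '/data/sample_predictions.h5'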


def _get_dataset_names(config, number_of_datasets, prefix='predictions'):
    dataset_names = config.get('dest_dataset_name')
    if dataset_names is not None:
        if isinstance(dataset_names, str):
            return [dataset_names]
        else:
            return dataset_names
    else:
        if number_of_datasets == 1:
            return [prefix]
        else:
            return [f'{prefix}{i}' for i in range(number_of_datasets)]
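
# e.g. (illustrative): with no 'dest_dataset_name' in the config,
#   _get_dataset_names(config, 1) -> ['predictions']
#   _get_dataset_names(config, 2) -> ['predictions0', 'predictions1']
# while config['dest_dataset_name'] = 'foo' always yields ['foo'].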


def main():
    # Load configuration
    config = load_config()

    # Create the model
    model = get_model(config)

    # Load model state
    model_path = config['model_path']
    logger.info(f'Loading model from {model_path}...')
    utils.load_checkpoint(model_path, model)

    logger.info(f"Sending the model to '{config['device']}'")
    model = model.to(config['device'])

    logger.info('Loading HDF5 datasets...')

    store_predictions_in_memory = config.get('store_predictions_in_memory', True)
    if store_predictions_in_memory:
        logger.info('Predictions will be stored in memory. Make sure you have enough RAM for your dataset.')

    for test_loader in get_test_loaders(config):
        logger.info(f"Processing '{test_loader.dataset.file_path}'...")

        output_file = _get_output_file(test_loader.dataset)
        # run the model prediction on the entire dataset and save to the 'output_file' H5
        if store_predictions_in_memory:
            predict_in_memory(model, test_loader, output_file, config)
        else:
            predict(model, test_loader, output_file, config)


if __name__ == '__main__':
    main()
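
# Usage sketch (load_config() in unet3d/config.py defines the exact CLI flag; it
# typically takes a YAML config path, e.g.):
#   python predict.py --config test_config.yaml
# The config is expected to provide at least 'model_path', 'device' and the 'model'
# section (out_channels / output_heads) read above, plus whatever get_test_loaders
# needs to locate the test HDF5 files.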