forked from SZUHvern/TMI_multi-contrast-registration
affine_train.py · 105 lines (85 loc) · 4.22 KB
import math
import os
import shutil
import time

from model import *

import h5py
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator

from losses import *


def generator_train_stn(input, label, batch_size=32):
    """Yield paired training batches of the two contrasts with identical random augmentation."""
    # gen_image = ImageDataGenerator(width_shift_range=0.1, height_shift_range=0.1, zoom_range=0.2, rotation_range=5, fill_mode='constant')
    # Note: featurewise_center / featurewise_std_normalization only take effect after the
    # generator has been fit on sample data; otherwise Keras warns and skips them.
    gen_image = ImageDataGenerator(width_shift_range=0.1, height_shift_range=0.1, rotation_range=5, zoom_range=0.1,
                                   featurewise_center=True, featurewise_std_normalization=True,
                                   horizontal_flip=True, fill_mode='constant')
    gen_mask = ImageDataGenerator(width_shift_range=0.1, height_shift_range=0.1, rotation_range=5, zoom_range=0.1,
                                  featurewise_center=True, featurewise_std_normalization=True,
                                  horizontal_flip=True, fill_mode='constant')
    # The shared seed keeps the random transforms of the two streams in sync.
    train_image = gen_image.flow(input, batch_size=batch_size, shuffle=True, seed=1)
    train_mask = gen_mask.flow(label, batch_size=batch_size, shuffle=True, seed=1)
    while True:
        next_train = next(train_image)
        next_mask = next(train_mask)
        yield [next_train, next_mask], [next_mask]


def generator_val_stn(input, label, batch_size=32):
    """Yield validation batches in fixed order, without augmentation or shuffling."""
    gen_image2 = ImageDataGenerator(featurewise_center=True, featurewise_std_normalization=True)
    gen_mask2 = ImageDataGenerator(featurewise_center=True, featurewise_std_normalization=True)
    train_image = gen_image2.flow(input, batch_size=batch_size, shuffle=False)
    train_mask = gen_mask2.flow(label, batch_size=batch_size, shuffle=False)
    while True:
        next_train = next(train_image)
        next_mask = next(train_mask)
        yield [next_train, next_mask], [next_mask]


if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # Paths and hyperparameters.
    path_h5_save = './ori_data/deal/'
    output_path = './model/'
    load_weight = ''
    mode = 'train'
    batch_size = 64
    lr = 1e-4
    w = 16
    h5_name = 'affine'
    subdir_savepic = h5_name
    output_path += h5_name + '/'

    if mode == 'train':
        # Start from a clean output directory for this run.
        if os.path.exists(output_path):
            shutil.rmtree(output_path)
        os.makedirs(output_path)

        # Build and compile the affine registration network.
        model = Affine_net(w=w, trainable=True)
        model.compile(optimizer=Adam(lr=lr),
                      loss=design_loss().mi_clipmse,
                      metrics={'affine_result': dice_coef})
        model.summary()

        if load_weight != '':
            print('loading:', load_weight)
            model.load_weights(load_weight, by_name=True)
        else:
            print('no loading weight!')

        # Load the two image contrasts (F and DWI) from the preprocessed HDF5 files.
        time_start = time.time()
        print(path_h5_save + 'train')
        h5 = h5py.File(path_h5_save + 'train', 'r')
        train_F = h5['F']
        train_DWI = h5['DWI']
        h5 = h5py.File(path_h5_save + 'test', 'r')
        val_F = h5['F']
        val_DWI = h5['DWI']
        label_dwi = h5['label_dwi']
        label_flair = h5['label_flair']

        num_train_steps = math.floor(len(train_F) / batch_size)
        num_val_steps = math.floor(len(val_F) / batch_size)
        print('training data:' + str(len(train_F)) + ' validation data:' + str(len(val_F)))
        print('used:', str(time.time() - time_start) + 's\n')

        time_start = time.time()
        train_data = generator_train_stn(train_F, train_DWI, batch_size=batch_size)
        val_data = generator_val_stn(val_F, val_DWI, batch_size=batch_size)

        # Callbacks: early stopping, TensorBoard logging, and best-model checkpointing.
        earlystop = EarlyStopping(monitor='val_loss', patience=20, verbose=1)
        tensorboard = TensorBoard(log_dir='./tensorboard/' + h5_name + '/', batch_size=batch_size)
        checkpointer = ModelCheckpoint(output_path + 'epoch{epoch:03d}-{val_dice_coef:.2f}.h5',
                                       monitor='val_loss', mode='min', verbose=1,
                                       save_best_only=True)

        history = model.fit_generator(train_data, epochs=100, initial_epoch=0, steps_per_epoch=num_train_steps, shuffle=True,
                                      callbacks=[checkpointer, tensorboard, earlystop], validation_data=val_data,
                                      validation_steps=num_val_steps, verbose=2)