-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
129 lines (103 loc) · 4.58 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import glob
import sys
import random
import time
import torch
import numpy as np
# import scipy.ndimage
from argparse import ArgumentParser
from torchsummary import summary
os.environ['VXM_BACKEND'] = 'pytorch'
import voxelmorph as vxm
def train(datadir,
          model_dir,
          load_model,
          gpu,
          initial_epoch,
          epochs,
          steps_per_epoch,
          batch_size,
          atlas=False,
          bidir=False):
    """Train a pivit registration model on scan-to-scan pairs.

    Args:
        datadir: directory containing training volumes as ``*.nii.gz`` files.
        model_dir: output directory for model checkpoints (created if missing).
        load_model: path to a state-dict checkpoint to initialize from, or a
            falsy value (e.g. ``False``/``''``) to train from scratch.
        gpu: GPU ID string for ``CUDA_VISIBLE_DEVICES`` (e.g. ``'0'``).
        initial_epoch: epoch index to resume counting from.
        epochs: total number of training epochs.
        steps_per_epoch: number of optimizer steps per epoch.
        batch_size: training batch size.
        atlas: unused here; kept for interface compatibility with callers.
        bidir: if True, the generator yields bidirectional training pairs.

    Raises:
        ValueError: if no training volumes are found in ``datadir``.
    """
    # Restrict visible devices BEFORE any CUDA initialization; setting this
    # after the model touches CUDA has no effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    torch.backends.cudnn.deterministic = True
    device = 'cuda'

    train_vol_names = glob.glob(os.path.join(datadir, '*.nii.gz'))
    random.shuffle(train_vol_names)  # randomize pairing order across runs
    if not train_vol_names:
        # raise (not assert) so the check survives `python -O`
        raise ValueError('Could not find any training data in %s' % datadir)

    # Volumes here are single-channel, so append a trailing feature axis.
    add_feat_axis = True
    generator = vxm.generators.scan_to_scan(train_vol_names, batch_size=batch_size, bidir=bidir,
                                            add_feat_axis=add_feat_axis)

    # extract spatial shape from one sampled input (drop batch and feature axes)
    inshape = next(generator)[0][0].shape[1:-1]

    # prepare model folder (exist_ok makes a second check redundant)
    os.makedirs(model_dir, exist_ok=True)

    # prepare the model
    model = vxm.pivit.pivit(inshape)
    model.to(device)
    summary(model)

    if load_model:
        print('loading', load_model)
        # map_location keeps loading robust when the checkpoint was saved
        # on a different device layout
        best_model = torch.load(load_model, map_location=device)
        model.load_state_dict(best_model)

    # set optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # prepare losses: image similarity (NCC) + flow-field smoothness (L2 grad)
    Losses = [vxm.losses.NCC().loss, vxm.losses.Grad_2('l2').loss]
    Weights = [1.0, 1.0]

    # training loop
    for epoch in range(initial_epoch, epochs):
        start_time = time.time()
        model.train()
        train_losses = []
        train_total_loss = []
        for step in range(steps_per_epoch):
            # generate inputs (and true outputs) and convert to channels-first
            # tensors: (B, D, H, W, C) -> (B, C, D, H, W)
            inputs, labels = next(generator)
            inputs = [torch.from_numpy(d).to(device).float().permute(0, 4, 1, 2, 3) for d in inputs]
            labels = [torch.from_numpy(d).to(device).float().permute(0, 4, 1, 2, 3) for d in labels]
            source = inputs[0]
            target = inputs[1]

            # run inputs through the model to produce a warped image and flow field
            pred = model(source, target)

            # calculate total weighted loss; pred[i] pairs with Losses[i]
            loss = 0
            loss_list = []
            for i, Loss in enumerate(Losses):
                curr_loss = Loss(pred[i], target) * Weights[i]
                loss_list.append(curr_loss.item())
                loss += curr_loss
            train_losses.append(loss_list)
            train_total_loss.append(loss.item())

            # backpropagate and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # print epoch info
        epoch_info = 'Epoch %d/%d' % (epoch + 1, epochs)
        time_info = 'Total %.2f sec' % (time.time() - start_time)
        train_losses = ', '.join(['%.4f' % f for f in np.mean(train_losses, axis=0)])
        train_loss_info = 'Train loss: %.4f (%s)' % (np.mean(train_total_loss), train_losses)
        print(' - '.join((epoch_info, time_info, train_loss_info)), flush=True)

        # save model checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), os.path.join(model_dir, '%04d.pt' % (epoch + 1)))
if __name__ == "__main__":
    # CLI entry point: argument names map 1:1 onto train()'s parameters.
    parser = ArgumentParser()
    parser.add_argument('datadir', help='base data directory')
    parser.add_argument('--model-dir', default='models', help='model output directory (default: models)')
    parser.add_argument('--load-model', default=False, help='optional model file to initialize with')
    parser.add_argument('--gpu', default='0', help='GPU ID number(s), comma-separated (default: 0)')
    parser.add_argument('--initial-epoch', type=int, default=0, help='initial epoch number (default: 0)')
    # help text previously claimed "default: 1500" while the default is 1000
    parser.add_argument('--epochs', type=int, default=1000, help='number of training epochs (default: 1000)')
    # help text previously said "frequency of model saves", which describes
    # the checkpoint interval, not this flag
    parser.add_argument('--steps-per-epoch', type=int, default=100, help='steps (batches) per epoch (default: 100)')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size (default: 1)')
    args = parser.parse_args()
    train(**vars(args))