-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrain.py
119 lines (95 loc) · 4.54 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import torch
from torch.utils.tensorboard import SummaryWriter
import config
from lib.dataset.codd import CODDAggDataset
from lib.dataset.kitti import KITTIOdometryDataset
from lib.dataset.datatransforms import Compose, VoxelSampling, RandomRotationTransform
from lib.models.fastreg import FastReg
from lib.utils.evaluation import registrationMetrics
def executeEpoch(model, loader, opt, sched, e, sw, mode='train'):
    """Run one full epoch of training or validation.

    Args:
        model: FastReg registration model (already on CUDA); called as
            model(pb, RtGT) and expected to return
            (R, t, loss_pos, loss_neg, loss, maxInliers, actualInliers).
        loader: DataLoader yielding (point-cloud batch, ground-truth Rt) pairs.
        opt: optimizer stepped once per batch (train mode only).
        sched: LR scheduler stepped once per epoch (train mode only).
        e: current epoch index (used for logging).
        sw: torch.utils.tensorboard.SummaryWriter for per-epoch scalars.
        mode: 'train' (gradients + optimizer updates) or 'val' (no_grad).
    """
    assert mode == 'train' or mode == 'val', 'mode should be train or val'
    if mode == 'train':
        model.train()
    else:
        model.eval()

    # Epoch accumulators (averaged over batches at the end).
    lE, lpE, lnE = 0, 0, 0
    rotE, transE = 0, 0
    maxInliersE, actualInliersE = 0, 0

    for b, (pb, RtGT) in enumerate(loader):
        pb = pb.cuda()
        RtGT = RtGT.cuda()
        if mode == 'train':
            R, t, loss_pos, loss_neg, loss, maxInliers, actualInliers = model(pb, RtGT)
        else:
            # Validation: no graph needed, saves memory and time.
            with torch.no_grad():
                R, t, loss_pos, loss_neg, loss, maxInliers, actualInliers = model(pb, RtGT)

        lE += loss.detach().item()
        lpE += loss_pos.detach().item()
        lnE += loss_neg.detach().item()

        # Per-batch median rotation/translation errors against ground truth.
        rotErr, transErr = registrationMetrics(RtGT, R, t)
        rotErr = rotErr.detach().median().item()
        transErr = transErr.detach().median().item()
        rotE += rotErr
        transE += transErr
        maxInliersE += maxInliers.item()
        actualInliersE += actualInliers.item()

        if mode == 'train':
            # Optimise model: backprop, step, then clear grads for next batch.
            loss.backward()
            opt.step()
            opt.zero_grad()
            print(f'E {e}/B {b}. Loss {loss.detach().item():.4f}. PLoss {loss_pos.detach().item():.4f}. NLoss {loss_neg.detach().item():.4f}. MRE {rotErr:.2f}, MTE {transErr:.2f}. Inliers {actualInliers.item()}/{maxInliers.item()}')

    # Average accumulated stats over the number of batches.
    batches = len(loader)
    lE /= batches
    lpE /= batches
    lnE /= batches
    rotE /= batches
    transE /= batches
    maxInliersE /= batches
    actualInliersE /= batches
    print(f'{mode} {e}. Loss {lE:.4f}. PLoss {lpE:.4f}. NLoss {lnE:.4f}. MRE {rotE:.2f}. MTE {transE:.2f}. MInliers {actualInliersE:.1f}/{maxInliersE:.1f}')

    # Update tensorboard (fix: trans_err was previously logged twice).
    sw.add_scalar(f'{mode}/loss', lE, e)
    sw.add_scalar(f'{mode}/loss_pos', lpE, e)
    sw.add_scalar(f'{mode}/loss_neg', lnE, e)
    sw.add_scalar(f'{mode}/rot_err', rotE, e)
    sw.add_scalar(f'{mode}/trans_err', transE, e)
    sw.add_scalar(f'{mode}/maxInliers', maxInliersE, e)
    sw.add_scalar(f'{mode}/actualInliers', actualInliersE, e)

    # Step the LR schedule once per training epoch.
    if mode == 'train':
        sched.step()
def train(args):
    """Train FastReg on the selected dataset.

    Args:
        args: parsed CLI namespace with
            - dataset: 'codd' or 'kitti', selects dataset class and path.
            - checkpoint: optional path to a state_dict to resume from.

    Raises:
        ValueError: if args.dataset is not a supported dataset name.

    Side effects: writes TensorBoard logs and per-epoch model checkpoints
    under 'runs/'.
    """
    # Dataset selection: training uses voxel sampling + random rotation
    # augmentation; validation uses voxel sampling only.
    if args.dataset == 'codd':
        trainDataset = CODDAggDataset(config.CODD_PATH, mode='train', transform=Compose([VoxelSampling(0.3), RandomRotationTransform(rsig=40)]))
        valDataset = CODDAggDataset(config.CODD_PATH, mode='val', transform=VoxelSampling(0.3))
    elif args.dataset == 'kitti':
        trainDataset = KITTIOdometryDataset(config.KITTI_PATH, mode='train', transform=Compose([VoxelSampling(0.3), RandomRotationTransform(rsig=40)]))
        valDataset = KITTIOdometryDataset(config.KITTI_PATH, mode='val', transform=VoxelSampling(0.3))
    else:
        # Defensive guard: argparse normally restricts choices, but avoid
        # an UnboundLocalError if train() is called programmatically.
        raise ValueError(f'unknown dataset: {args.dataset!r}')

    trainLoader = torch.utils.data.DataLoader(trainDataset, batch_size=config.batch_size, pin_memory=True, drop_last=True, num_workers=config.batch_size, shuffle=True)
    valLoader = torch.utils.data.DataLoader(valDataset, batch_size=config.batch_size, pin_memory=True, drop_last=True, num_workers=config.batch_size)

    model = FastReg(config.T).cuda()
    opt = torch.optim.Adam(model.parameters(), lr=1e-1, eps=1e-4)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.5)

    if args.checkpoint:
        # Resume from a saved state_dict; model is already on CUDA, so the
        # previously redundant second .cuda() call was removed.
        model.load_state_dict(torch.load(args.checkpoint))

    expPath = 'runs/'
    writer = SummaryWriter(expPath)
    for e in range(config.epochs):
        executeEpoch(model, trainLoader, opt, sched, e, writer, mode='train')
        if (e + 1) % config.val_period == 0:
            # Run validation
            executeEpoch(model, valLoader, opt, sched, e, writer, mode='val')
        # Saves model after every epoch
        torch.save(model.state_dict(), f'{expPath}/model{e}.pth')
    writer.close()
if __name__ == '__main__':
    # CLI entry point: positional dataset name plus optional resume checkpoint.
    cli = argparse.ArgumentParser(description='Trains FastReg registration model')
    cli.add_argument('dataset', choices=('codd', 'kitti'), help='dataset used for evaluation')
    cli.add_argument('--checkpoint', type=str, help='path to model checkpoint (continue training)')
    train(cli.parse_args())