-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathrun_cat_vkitti.py
74 lines (60 loc) · 2.44 KB
/
run_cat_vkitti.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import os
from torch.utils.data import ConcatDataset
from cat_net.options import Options
from cat_net.models import CATModel
from cat_net.datasets import vkitti
from cat_net import experiment
### COMMAND LINE ARGUMENTS ###
parser = argparse.ArgumentParser(
    description='Train and/or test CAT-Net on Virtual KITTI, holding out '
                'each sequence in turn as the test sequence.')
parser.add_argument('stage', type=str, choices=['train', 'test', 'both'],
                    help='stage(s) to run for every held-out sequence')
parser.add_argument('--resume', action='store_true',
                    help="pass resume_from_epoch='latest' to training")
# The defaults below are the paths this script originally hard-coded, so
# existing invocations behave exactly as before; override them to run the
# script on a machine with a different storage layout.
parser.add_argument('--data-dir', type=str,
                    default='/media/m2-drive/datasets/virtual-kitti/raw',
                    help='root directory of the raw Virtual KITTI data')
parser.add_argument('--results-dir', type=str,
                    default='/media/raid5-array/experiments/cat-net/virtual-kitti',
                    help='directory under which experiment outputs are written')
args = parser.parse_args()
resume_from_epoch = 'latest' if args.resume else None

### CONFIGURATION ###
opts = Options()
opts.data_dir = args.data_dir
opts.results_dir = args.results_dir
opts.down_levels = 7
opts.innermost_kernel_size = (3, 4)
### SET TRAINING, VALIDATION AND TEST SETS ###
# Each sequence in `seqs` is held out once as the test sequence; the model
# for that fold is trained on all remaining sequences.
seqs = ['0001', '0002', '0006', '0018', '0020']
conds = ['clone', 'morning', 'overcast', 'sunset']
# Every condition is paired with this single target condition ('overcast').
canonical = conds[2]

run_train = args.stage in ('train', 'both')
run_test = args.stage in ('test', 'both')


def _concat_pairs(tag, options, seq_list, cond_list, target_cond,
                  random_crop):
    """Build a ConcatDataset of ``cond --> target_cond`` Virtual KITTI pairs.

    One ``vkitti.TorchDataset`` is created for every (sequence, condition)
    combination, in sequence-major order, and a progress line prefixed with
    ``tag`` (e.g. 'Train' or 'Val') is printed for each.
    """
    parts = []
    for seq in seq_list:
        for cond in cond_list:
            print('{} {}: {} --> {}'.format(tag, seq, cond, target_cond))
            parts.append(vkitti.TorchDataset(
                options, seq, cond, target_cond, random_crop))
    return ConcatDataset(parts)


for test_seq in seqs:
    # Train on every other sequence (all conditions); validate on the
    # held-out sequence under the first condition ('clone') only.
    train_seqs = [seq for seq in seqs if seq != test_seq]
    val_seqs = [test_seq]
    val_conds = [conds[0]]

    # The training/validation sets are only consumed by experiment.train, so
    # skip building them when this invocation only tests.
    if run_train:
        train_data = _concat_pairs('Train', opts, train_seqs, conds,
                                   canonical, opts.random_crop)
        val_data = _concat_pairs('Val', opts, val_seqs, val_conds,
                                 canonical, False)

    ### TRAIN / TEST ###
    opts.experiment_name = '{}-test'.format(test_seq)
    model = CATModel(opts)

    if run_train:
        print(opts)
        opts.save_txt()
        experiment.train(opts, model, train_data, val_data,
                         opts.train_epochs, resume_from_epoch=resume_from_epoch)

    if run_test:
        # NOTE(review): in 'test'-only mode `model` is freshly constructed;
        # this presumably relies on experiment.test loading trained weights
        # for opts.experiment_name — confirm.
        for cond in conds:
            print('Test {}: {} --> {}'.format(test_seq, cond, canonical))
            expdir = os.path.join(opts.experiment_name,
                                  '{}-test'.format(cond))
            test_data = vkitti.TorchDataset(
                opts, test_seq, cond, canonical, False)
            experiment.test(opts, model, test_data, expdir=expdir,
                            save_loss=True, save_images=True)