-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path main.py
25 lines (20 loc) · 806 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from os.path import dirname, join
import trainer
from options import Options
# Build the option handler and parse the options once, at import time
# (presumably from sys.argv -- see options.Options.parse). The handler object is
# kept at module level, not just the parsed result, because the __main__
# section calls its update_opt_from_json() for phases that resume an earlier run.
options_handler = Options()
options = options_handler.parse()
# Every phase this script knows how to run.
KNOWN_PHASES = ('train_tea', 'train_stu', 'test_tea', 'test_stu')

# Phases that start from an earlier run: their options are refreshed from the
# flags.json stored in the same directory as ``options.resume``.
RESUME_PHASES = ('test_tea', 'test_stu', 'train_stu')


def run_phase(tr, phase):
    """Invoke the Trainer entry point that corresponds to *phase*.

    ``train_tea`` -> ``tr.train()``, ``train_stu`` -> ``tr.train_student()``,
    ``test_tea`` / ``test_stu`` -> ``tr.test()``.

    Args:
        tr: A ``trainer.Trainer`` (or any object exposing ``train``,
            ``train_student`` and ``test``).
        phase: One of ``KNOWN_PHASES``.

    Returns:
        Whatever the selected Trainer method returns.

    Raises:
        ValueError: If *phase* is not in ``KNOWN_PHASES``.
    """
    if phase == 'train_tea':
        return tr.train()
    if phase == 'train_stu':
        return tr.train_student()
    if phase in ('test_tea', 'test_stu'):
        return tr.test()
    raise ValueError(
        f'unknown phase {phase!r}; expected one of {", ".join(KNOWN_PHASES)}')


if __name__ == "__main__":
    # Fail loudly on an unrecognised phase; otherwise nothing would be built
    # or run and the script would still exit with status 0.
    if options.phase not in KNOWN_PHASES:
        raise SystemExit(
            f'unknown phase {options.phase!r}; '
            f'expected one of {", ".join(KNOWN_PHASES)}')

    if options.phase in RESUME_PHASES:
        print(f'resume from {options.resume}')
        # Presumably merges the flags saved next to the resume path into the
        # current options (see Options.update_opt_from_json).
        options = options_handler.update_opt_from_json(
            join(dirname(options.resume), 'flags.json'), options)

    tr = trainer.Trainer(options)
    print(tr.opt.phase, '-->', tr.opt.runsPath)
    # NOTE(review): dispatch uses options.phase *after* the JSON update, as the
    # original code did; if update_opt_from_json can overwrite `phase` with the
    # value saved by the earlier run, a test phase could end up training --
    # confirm it preserves the requested phase.
    run_phase(tr, options.phase)