-
Notifications
You must be signed in to change notification settings - Fork 88
/
Copy path eval.py
135 lines (108 loc) · 4.21 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import json
import pprint
import random
import time
import torch
import torch.multiprocessing as mp
from models.nn.resnet import Resnet
from data.preprocess import Dataset
from importlib import import_module
class Eval(object):
    '''
    Base class for evaluating a trained model on a dataset split using
    several worker processes.

    Subclasses must implement `run`, `evaluate`, `save_results` and
    `create_stats`. `create_stats` is expected to set up the shared
    `successes`, `failures` and `results` containers that `spawn_threads`
    hands to each worker (NOTE(review): `self.results` is not created in this
    class — confirm every subclass's `create_stats` defines it).
    '''

    # tokens
    STOP_TOKEN = "<<stop>>"
    SEQ_TOKEN = "<<seg>>"
    TERMINAL_TOKENS = [STOP_TOKEN, SEQ_TOKEN]

    def __init__(self, args, manager):
        '''
        Load the splits file, the model and the ResNet feature extractor;
        optionally preprocess the dataset.

        :param args: parsed options; reads `splits`, `model`, `model_path`,
            `data`, `preprocess`, `fast_epoch` and `gpu`. NOTE: this method
            also sets `args.visual_model = 'resnet18'` on the caller's object.
        :param manager: multiprocessing manager; its `Queue()` and `Lock()`
            are used later to share work and a lock between worker processes.
        '''
        # args and manager
        self.args = args
        self.manager = manager

        # load splits and print how many entries each split has
        with open(self.args.splits) as f:
            self.splits = json.load(f)
        pprint.pprint({k: len(v) for k, v in self.splits.items()})

        # load model (the optimizer returned alongside it is not needed here)
        print("Loading: ", self.args.model_path)
        M = import_module(self.args.model)
        self.model, _ = M.Module.load(self.args.model_path)
        # share parameters with worker processes instead of copying them
        self.model.share_memory()
        self.model.eval()
        self.model.test_mode = True

        # updated args: outputs go next to the checkpoint file
        self.model.args.dout = self._model_dir(self.args.model_path)
        self.model.args.data = self.args.data or self.model.args.data

        # preprocess and save
        if args.preprocess:
            print("\nPreprocessing dataset and saving to %s folders ... This is will take a while. Do this once as required:" % self.model.args.pp_folder)
            self.model.args.fast_epoch = self.args.fast_epoch
            dataset = Dataset(self.model.args, self.model.vocab)
            dataset.preprocess_splits(self.splits)

        # load resnet
        args.visual_model = 'resnet18'
        self.resnet = Resnet(args, eval=True, share_memory=True, use_conv_feat=True)

        # gpu
        if self.args.gpu:
            self.model = self.model.to(torch.device('cuda'))

        # success and failure lists (implemented by subclasses)
        self.create_stats()

        # set random seed for shuffling
        random.seed(int(time.time()))

    @staticmethod
    def _model_dir(model_path):
        '''
        Return the directory prefix of `model_path`, keeping the trailing '/'.

        Only the final path component is stripped. The previous
        `model_path.replace(basename, '')` also removed earlier occurrences of
        the basename, e.g. 'runs/best/best' -> 'runs//'.
        '''
        head, sep, _ = model_path.rpartition('/')
        return head + sep

    def queue_tasks(self):
        '''
        Create a manager queue filled with the trajectories to evaluate.

        Entries come from `self.splits[self.args.eval_split]`, truncated to
        the first 16 when `args.fast_epoch` is set and shuffled when
        `args.shuffle` is set. The list stored in `self.splits` is never
        modified.
        '''
        task_queue = self.manager.Queue()
        # copy so that shuffling below does not reorder self.splits in place
        files = list(self.splits[self.args.eval_split])

        # debugging: fast epoch
        if self.args.fast_epoch:
            files = files[:16]

        if self.args.shuffle:
            random.shuffle(files)

        for traj in files:
            task_queue.put(traj)
        return task_queue

    def spawn_threads(self):
        '''
        Run evaluation in `args.num_threads` worker processes, wait for all
        of them to finish, then call `save_results`.

        Despite the name, the workers are `torch.multiprocessing` processes,
        not threads. Each one runs `self.run` on the shared task queue.
        '''
        task_queue = self.queue_tasks()

        # start worker processes
        workers = []
        lock = self.manager.Lock()
        for _ in range(self.args.num_threads):
            worker = mp.Process(target=self.run, args=(self.model, self.resnet, task_queue, self.args, lock,
                                                       self.successes, self.failures, self.results))
            worker.start()
            workers.append(worker)

        for worker in workers:
            worker.join()

        # save
        self.save_results()

    @classmethod
    def setup_scene(cls, env, traj_data, r_idx, args, reward_type='dense'):
        '''
        Initialize the scene and agent from the task info.

        Resets `env` to the trajectory's floor plan, restores the recorded
        object poses/toggles/dirty-and-empty state, executes the initial
        action, prints the `r_idx`-th annotation's task description and
        registers the task on `env` with the given `reward_type`.
        '''
        # scene setup
        scene_num = traj_data['scene']['scene_num']
        object_poses = traj_data['scene']['object_poses']
        dirty_and_empty = traj_data['scene']['dirty_and_empty']
        object_toggles = traj_data['scene']['object_toggles']

        scene_name = 'FloorPlan%d' % scene_num
        env.reset(scene_name)
        env.restore_scene(object_poses, object_toggles, dirty_and_empty)

        # initialize to start position
        env.step(dict(traj_data['scene']['init_action']))

        # print goal instr
        print("Task: %s" % (traj_data['turk_annotations']['anns'][r_idx]['task_desc']))

        # setup task for reward
        env.set_task(traj_data, args, reward_type=reward_type)

    @classmethod
    def run(cls, model, resnet, task_queue, args, lock, successes, failures, results=None):
        '''
        Worker entry point executed in each process started by
        `spawn_threads`. `results` is the eighth positional argument that
        `spawn_threads` passes; it defaults to None for callers that omit it.
        Must be implemented by subclasses.
        '''
        raise NotImplementedError()

    @classmethod
    def evaluate(cls, env, model, r_idx, resnet, traj_data, args, lock, successes, failures, results=None):
        '''
        Evaluate a single trajectory. `results` mirrors the extra container
        handed to `run`. Must be implemented by subclasses.
        '''
        raise NotImplementedError()

    def save_results(self):
        '''Persist evaluation results; must be implemented by subclasses.'''
        raise NotImplementedError()

    def create_stats(self):
        '''Create the shared statistics containers; must be implemented by subclasses.'''
        raise NotImplementedError()