evaluator.py · 127 lines (108 loc) · 5.02 KB
import gfootball.env as football_env
import time, pprint, importlib, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import torch.multiprocessing as mp
from os import listdir
from os.path import isfile, join
from datetime import datetime, timedelta
def state_to_tensor(state_dict, h_in):
    # Wrap each encoded feature as a float tensor with batch and sequence dims of 1
    # ([1, 1, feat_dim]) so it matches the LSTM input layout expected by the model.
    player_state = torch.from_numpy(state_dict["player"]).float().unsqueeze(0).unsqueeze(0)
    ball_state = torch.from_numpy(state_dict["ball"]).float().unsqueeze(0).unsqueeze(0)
    left_team_state = torch.from_numpy(state_dict["left_team"]).float().unsqueeze(0).unsqueeze(0)
    left_closest_state = torch.from_numpy(state_dict["left_closest"]).float().unsqueeze(0).unsqueeze(0)
    right_team_state = torch.from_numpy(state_dict["right_team"]).float().unsqueeze(0).unsqueeze(0)
    right_closest_state = torch.from_numpy(state_dict["right_closest"]).float().unsqueeze(0).unsqueeze(0)
    avail = torch.from_numpy(state_dict["avail"]).float().unsqueeze(0).unsqueeze(0)

    state_dict_tensor = {
        "player": player_state,
        "ball": ball_state,
        "left_team": left_team_state,
        "left_closest": left_closest_state,
        "right_team": right_team_state,
        "right_closest": right_closest_state,
        "avail": avail,
        "hidden": h_in
    }
    return state_dict_tensor
def get_action(a_prob, m_prob):
    # Sample from the hierarchical policy: the first head picks a high-level action a,
    # and only when a == 1 (move) is the second head sampled for a movement direction m.
    a = Categorical(a_prob).sample().item()
    m, need_m = 0, 0
    prob_selected_a = a_prob[0][0][a].item()
    prob_selected_m = 0
    if a == 0:
        real_action = a                # idle
        prob = prob_selected_a
    elif a == 1:
        m = Categorical(m_prob).sample().item()
        need_m = 1
        real_action = m + 1            # movement actions occupy environment ids 1-8
        prob_selected_m = m_prob[0][0][m].item()
        prob = prob_selected_a * prob_selected_m
    else:
        real_action = a + 7            # non-movement actions start after the 8 directions
        prob = prob_selected_a

    assert prob != 0, 'prob 0 ERROR!!!! a : {}, m:{} {}, {}'.format(a, m, prob_selected_a, prob_selected_m)
    return real_action, a, m, need_m, prob, prob_selected_a, prob_selected_m
def evaluator(center_model, signal_queue, summary_queue, arg_dict):
    print("Evaluator process started")
    fe_module = importlib.import_module("encoders." + arg_dict["encoder"])
    rewarder = importlib.import_module("rewarders." + arg_dict["rewarder"])
    imported_model = importlib.import_module("models." + arg_dict["model"])

    fe = fe_module.FeatureEncoder()
    model = center_model  # evaluation runs directly on the shared central model

    env = football_env.create_environment(env_name=arg_dict["env_evaluation"], representation="raw",
                                          stacked=False, logdir='/tmp/football',
                                          write_goal_dumps=False, write_full_episode_dumps=False, render=False)
    n_epi = 0
    while True:  # episode loop
        env.reset()
        done = False
        steps, score, tot_reward, win = 0, 0, 0, 0
        n_epi += 1
        # initial LSTM hidden and cell states
        h_out = (torch.zeros([1, 1, arg_dict["lstm_size"]], dtype=torch.float),
                 torch.zeros([1, 1, arg_dict["lstm_size"]], dtype=torch.float))
        loop_t, forward_t, wait_t = 0.0, 0.0, 0.0
        obs = env.observation()

        while not done:  # step loop
            init_t = time.time()
            is_stopped = False
            while signal_queue.qsize() > 0:
                # wait while the learner holds the signal queue (e.g. during a model update)
                time.sleep(0.02)
                is_stopped = True
            if is_stopped:
                # model.load_state_dict(center_model.state_dict())
                pass
            wait_t += time.time() - init_t

            h_in = h_out
            state_dict = fe.encode(obs[0])
            state_dict_tensor = state_to_tensor(state_dict, h_in)

            t1 = time.time()
            with torch.no_grad():
                a_prob, m_prob, _, h_out = model(state_dict_tensor)
            forward_t += time.time() - t1

            real_action, a, m, need_m, prob, prob_selected_a, prob_selected_m = get_action(a_prob, m_prob)

            prev_obs = obs
            obs, rew, done, info = env.step(real_action)
            fin_r = rewarder.calc_reward(rew, prev_obs[0], obs[0])
            state_prime_dict = fe.encode(obs[0])

            (h1_in, h2_in) = h_in
            (h1_out, h2_out) = h_out
            state_dict["hidden"] = (h1_in.numpy(), h2_in.numpy())
            state_prime_dict["hidden"] = (h1_out.numpy(), h2_out.numpy())
            # transition mirrors the actor processes but is not pushed to any queue here
            transition = (state_dict, a, m, fin_r, state_prime_dict, prob, done, need_m)

            steps += 1
            score += rew
            tot_reward += fin_r

            if arg_dict['print_mode']:
                print_status(steps, a, m, prob_selected_a, prob_selected_m, prev_obs, obs, fin_r, tot_reward)

            loop_t += time.time() - init_t

            if done:
                if score > 0:
                    win = 1
                print("score", score, "total reward", tot_reward)
                # per-episode summary: result, rewards, and per-step timing statistics
                summary_data = (win, score, tot_reward, steps, arg_dict['env_evaluation'],
                                loop_t / steps, forward_t / steps, wait_t / steps)
                summary_queue.put(summary_data)
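

# ---------------------------------------------------------------------------
# Minimal launch sketch (not part of the original training pipeline): how an
# evaluator process could be spawned on its own for testing. The arg_dict
# values below ("encoder_basic", "rewarder_basic", "conv1d", lstm_size=256)
# and the Model(arg_dict) constructor are assumptions and must match the
# actual encoders/, rewarders/, and models/ packages in this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    mp.set_start_method("spawn")
    arg_dict = {
        "encoder": "encoder_basic",        # assumed encoder module name
        "rewarder": "rewarder_basic",      # assumed rewarder module name
        "model": "conv1d",                 # assumed model module name
        "env_evaluation": "11_vs_11_easy_stochastic",
        "lstm_size": 256,
        "print_mode": False,
    }
    model_module = importlib.import_module("models." + arg_dict["model"])
    center_model = model_module.Model(arg_dict)  # assumed constructor signature
    center_model.share_memory()                  # share weights with child processes

    signal_queue, summary_queue = mp.Queue(), mp.Queue()
    p = mp.Process(target=evaluator, args=(center_model, signal_queue, summary_queue, arg_dict))
    p.start()
    p.join()  # the evaluator loops forever; interrupt to stop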