forked from nkullman/ridehailing_package
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsimplemain.py
48 lines (31 loc) · 1.09 KB
/
simplemain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import matplotlib.pyplot as plt
from pyhailing import RidehailEnv
def _show_frame(env):
    """Render the environment's current state and display it with matplotlib.

    ``plt.show()`` is called once per frame, so depending on the active
    matplotlib backend, execution may pause until the figure window is closed.
    """
    plt.imshow(env.render())
    plt.show()


def main(render: bool = False, *, random_actions: bool = False) -> list:
    """Roll out a baseline policy on the DIMACS "SUI" configuration and report rewards.

    Runs ``RidehailEnv.DIMACS_NUM_EVAL_EPISODES`` episodes. After each episode
    the total reward is printed; at the end the average across episodes is printed.

    Args:
        render: If True, display a frame after every reset and every step.
        random_actions: If True, act with ``env.get_random_action()``;
            otherwise (the default) act with ``env.get_noop_action()``.

    Returns:
        The total reward of each episode, in episode order.
    """
    # Copy the shared class-level config before modifying it. Assigning into
    # RidehailEnv.DIMACS_CONFIGS.SUI directly would leak the "testing"
    # nickname into every later use of that config in this process.
    env_config = dict(RidehailEnv.DIMACS_CONFIGS.SUI)
    env_config["nickname"] = "testing"
    env = RidehailEnv(**env_config)

    all_eps_rewards = []
    for episode in range(RidehailEnv.DIMACS_NUM_EVAL_EPISODES):
        # obs is not consumed by these baseline policies; a learned agent
        # would choose its action from it.
        obs = env.reset()
        terminal = False
        reward = 0
        if render:
            _show_frame(env)

        while not terminal:
            action = env.get_random_action() if random_actions else env.get_noop_action()
            next_obs, new_rwd, terminal, _ = env.step(action)
            reward += new_rwd
            obs = next_obs
            if render:
                _show_frame(env)

        print(f"Episode {episode} complete. Reward: {reward}")
        all_eps_rewards.append(reward)

    # Guard against division by zero if no episodes were configured.
    if all_eps_rewards:
        mean_reward = sum(all_eps_rewards) / len(all_eps_rewards)
        print(f"All episodes complete. Average reward: {mean_reward}")
    else:
        print("No episodes were run.")
    return all_eps_rewards


if __name__ == "__main__":
    main()