-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutil.py
144 lines (109 loc) · 4.4 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""Useful functions for the Game.
Authors:
Gael Colas
"""
import numpy as np
import ujson as json
import threading
def input_thread(inputs_list):
    """Wait for one line of user input and record it.

    Meant to run in a background thread so the game loop is not blocked
    while the user types.

    Args:
        'inputs_list' (list): list the typed line (without its trailing newline) is appended to
    """
    inputs_list.append(input())
def display_info(n_sim, highscore, commands_filename):
    """Print the number of simulations, the highscores and the list of commands.

    Args:
        'n_sim' (int): number of simulations played
        'highscore' (sequence of 2 ints, (human, AI)): the best score achieved by a human and an AI
        'commands_filename' (str): filename of the text file listing the commands used in the game
    """
    simulation_text = "Simulations: {}\n".format(n_sim)
    score_text = "Highscore Human: {}\nHighscore AI: {}\n".format(*highscore)
    # explicit encoding so the command listing reads the same on every platform
    with open(commands_filename, "r", encoding="utf-8") as commands_file:
        commands_text = commands_file.read()
    input_text = "\nENTER COMMAND:\n"
    print(simulation_text, score_text, commands_text, input_text, sep='\n')
def handle_user_command(gym):
    """Execute the latest user command, if any, and listen for the next one.

    Commands:
        's': save the AI agent to 'gym.args.save_filename'
        'q': call 'gym.agent.dino.quit()'
        'h': switch to human play ('gym.isHuman = True')
        'a': switch to AI play ('gym.isHuman = False')
    Any other input is ignored.

    Args:
        'gym' (Gym): agent training gym
    Remarks:
        The possible commands are specified in "commands.txt".
    """
    # list filled by the currently running input thread
    inputs_list = gym.inputs_list
    if not inputs_list:
        # nothing typed yet: the input thread is still waiting
        return

    # execute the last command; strip stray whitespace so "s " still counts as "s"
    command = inputs_list[-1].strip()
    if command == "s":
        save_agent(gym.agent, gym.args.save_filename)
    elif command == "q":
        gym.agent.dino.quit()
    elif command == "h":
        gym.isHuman = True
    elif command == "a":
        gym.isHuman = False

    # reset the list of user commands and launch a new listener on the fresh list.
    # daemon=True: a thread blocked in input() must not keep the process alive
    # once the main program has finished.
    gym.inputs_list = []
    threading.Thread(target=input_thread, args=(gym.inputs_list,), daemon=True).start()
def load_highscore(highscore_filename):
    """Load the highscores stored in a text file.

    Each line is expected to be '<name> <score>'. A name containing "human"
    sets the human score; a name containing "ai" sets the AI score. Blank or
    malformed lines are skipped. If the file does not exist, an empty one is
    created.

    Args:
        'highscore_filename' (str): filename of the highscore text file
    Return:
        'highscore' (list of int, [human, AI]): the best score achieved by a human and an AI,
            -1 for any score not found in the file
    """
    human_score, ai_score = -1, -1
    try:
        with open(highscore_filename, "r") as highscore_file:
            lines = highscore_file.readlines()
    except FileNotFoundError:
        print("No highscore file '{}' found. Creating a new one...".format(highscore_filename))
        # create the empty file and close the handle right away
        open(highscore_filename, "w").close()
        return [human_score, ai_score]

    for line in lines:
        fields = line.split()
        if len(fields) != 2:
            # blank or malformed line: ignore it instead of crashing
            continue
        name, score = fields
        if "human" in name:
            human_score = int(score)
        elif "ai" in name:
            ai_score = int(score)

    highscore = [human_score, ai_score]
    return highscore
def update_score(highscore, highscore_filename):
    """Overwrite the highscore text file with the given scores.

    The file gets exactly two lines, 'human <score>' and 'ai <score>',
    with no trailing newline.

    Args:
        'highscore' (sequence of int, (human, AI)): the best score achieved by a human and an AI
        'highscore_filename' (str): filename of the highscore text file
    """
    with open(highscore_filename, "w") as score_file:
        human_line = "human {}".format(highscore[0])
        ai_line = "ai {}".format(highscore[1])
        score_file.write("\n".join([human_line, ai_line]))
def _to_jsonable(obj):
    """Recursively convert NumPy arrays/scalars into plain Python lists/numbers."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {key: _to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(item) for item in obj]
    return obj


def save_agent(agent, out_filename):
    """Save the agent parameters to a JSON file.

    NumPy arrays and scalars in 'agent.mdp_data' (as set by load_agent) are
    converted to plain lists/numbers, because JSON encoders cannot serialize
    them. 'agent.mdp_data' itself is left unchanged.

    Args:
        'agent' (AIAgent): AI agent to save
        'out_filename' (str): name of the output file
    """
    serializable_data = _to_jsonable(agent.mdp_data)
    with open(out_filename, "w") as out_file:
        json.dump(serializable_data, out_file)
    print("The AI agent has been saved to: {}".format(out_filename))
def load_agent(agent, in_filename):
    """Restore the agent parameters from a JSON file.

    JSON stores the numeric tables as nested lists, so every table (and each
    entry of 'state_discretization') is turned back into a np.ndarray before
    being stored in 'agent.mdp_data'.

    Args:
        'agent' (AIAgent): AI agent whose 'mdp_data' attribute is replaced
        'in_filename' (str): name of the input file
    """
    with open(in_filename, "r") as in_file:
        raw_data = json.load(in_file)

    restored = {
        'num_states': raw_data['num_states'],
        'state_discretization': list(map(np.array, raw_data['state_discretization'])),
    }
    for table_name in ('transition_counts', 'transition_probs', 'reward_counts', 'reward', 'value'):
        restored[table_name] = np.array(raw_data[table_name])
    agent.mdp_data = restored

    print("The AI agent has been loaded from: {}".format(in_filename))